author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-01-26 23:18:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-01-26 23:18:06 +0100 |
commit | 68640a70a769568fd571e50322d6e3fa40c78271 (patch) | |
tree | bf888f83b72ff3ef09edecb35ad41c7839f3733c | /text_recognizer/networks/transformer |
parent | 53a4af4ca22ced165fde14cd0de46e29aab7d80d (diff) | |
fix: refactor AttentionLayers
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 54 |
1 file changed, 25 insertions, 29 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index fc32f20..4263f52 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -2,7 +2,6 @@
 from copy import deepcopy
 from typing import Any, Optional, Tuple, Type
 
-import attr
 from torch import nn, Tensor
 
 from text_recognizer.networks.transformer.attention import Attention
@@ -10,26 +9,24 @@ from text_recognizer.networks.transformer.mlp import FeedForward
 from text_recognizer.networks.transformer.residual import Residual
 
 
-@attr.s(eq=False)
 class AttentionLayers(nn.Module):
     """Standard transfomer layer."""
 
-    def __attrs_pre_init__(self) -> None:
+    def __init__(
+        self,
+        depth: int,
+        self_attn: Attention,
+        norm: Type[nn.Module],
+        ff: FeedForward,
+        cross_attn: Optional[Attention] = None,
+        pre_norm: bool = True,
+        has_pos_emb: bool = True,
+    ) -> None:
         super().__init__()
-
-    depth: int = attr.ib()
-    self_attn: Attention = attr.ib()
-    norm: Type[nn.Module] = attr.ib()
-    ff: FeedForward = attr.ib()
-    cross_attn: Optional[Attention] = attr.ib(default=None)
-    pre_norm: bool = attr.ib(default=True)
-    has_pos_emb: bool = attr.ib(default=False)
-    layer_types: Tuple[str, ...] = attr.ib(init=False)
-    layers: nn.ModuleList = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        self.layer_types = self._get_layer_types() * self.depth
-        self.layers = self._build()
+        self.pre_norm = pre_norm
+        self.has_pos_emb = has_pos_emb
+        self.layer_types = self._get_layer_types() * depth
+        self.layers = self._build(self_attn, norm, ff, cross_attn)
 
     def _get_layer_types(self) -> Tuple:
         """Get layer specification."""
@@ -37,24 +34,23 @@ class AttentionLayers(nn.Module):
             return "a", "c", "f"
         return "a", "f"
 
-    def _delete(self) -> None:
-        del self.self_attn
-        del self.ff
-        del self.norm
-        del self.cross_attn
-
-    def _build(self) -> nn.ModuleList:
+    def _build(
+        self,
+        self_attn: Attention,
+        norm: Type[nn.Module],
+        ff: FeedForward,
+        cross_attn: Optional[Attention],
+    ) -> nn.ModuleList:
         """Configures transformer network."""
         layers = nn.ModuleList([])
         for layer_type in self.layer_types:
             if layer_type == "a":
-                layer = deepcopy(self.self_attn)
+                layer = deepcopy(self_attn)
             elif layer_type == "c":
-                layer = deepcopy(self.cross_attn)
+                layer = deepcopy(cross_attn)
             elif layer_type == "f":
-                layer = deepcopy(self.ff)
-            layers.append(nn.ModuleList([deepcopy(self.norm), layer, Residual()]))
-            self._delete()
+                layer = deepcopy(ff)
+            layers.append(nn.ModuleList([deepcopy(norm), layer, Residual()]))
         return layers
 
     def forward(
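For context, the refactor drops the attrs-generated constructor (`attr.ib`, `__attrs_pre_init__`, `__attrs_post_init__`) in favour of a plain `nn.Module.__init__` that receives template sub-modules as arguments and deep-copies them once per layer, keeping only configuration flags (`pre_norm`, `has_pos_emb`) and the built `nn.ModuleList` as attributes. The sketch below illustrates that construction pattern in isolation; it is not the project's code: `ToyLayers` is a hypothetical stand-in, generic modules (`nn.Identity`, `nn.LayerNorm`, `nn.Linear`) replace the repo's `Attention`, `FeedForward`, and `Residual`, and `forward` is omitted.

```python
from copy import deepcopy
from typing import Optional, Tuple

from torch import nn


class ToyLayers(nn.Module):
    """Toy stand-in showing the plain-__init__ construction pattern."""

    def __init__(
        self,
        depth: int,
        self_attn: nn.Module,
        norm: nn.Module,
        ff: nn.Module,
        cross_attn: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Only configuration state is kept on the instance; the template
        # modules are consumed here rather than stored as attributes.
        self.has_cross_attn = cross_attn is not None
        self.layer_types: Tuple[str, ...] = self._get_layer_types() * depth
        self.layers = self._build(self_attn, norm, ff, cross_attn)

    def _get_layer_types(self) -> Tuple[str, ...]:
        """Per-block layer spec: self-attention, optional cross-attention, feed-forward."""
        return ("a", "c", "f") if self.has_cross_attn else ("a", "f")

    def _build(
        self,
        self_attn: nn.Module,
        norm: nn.Module,
        ff: nn.Module,
        cross_attn: Optional[nn.Module],
    ) -> nn.ModuleList:
        """Deep-copy the templates so every block gets independent weights."""
        layers = nn.ModuleList([])
        for layer_type in self.layer_types:
            if layer_type == "a":
                layer = deepcopy(self_attn)
            elif layer_type == "c":
                layer = deepcopy(cross_attn)
            else:  # "f"
                layer = deepcopy(ff)
            layers.append(nn.ModuleList([deepcopy(norm), layer]))
        return layers


net = ToyLayers(depth=2, self_attn=nn.Identity(), norm=nn.LayerNorm(8), ff=nn.Linear(8, 8))
print(net.layer_types)  # ('a', 'f', 'a', 'f') -- no cross-attention template was given
```

Passing the templates as constructor arguments, instead of storing them on the instance and deleting them afterwards as the removed `_delete` helper did, keeps the per-depth copies independent while avoiding temporary attributes on the module.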