Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
 text_recognizer/networks/transformer/layers.py | 16 +---------------
 1 file changed, 1 insertion(+), 15 deletions(-)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 8387fa4..67558ad 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -6,7 +6,6 @@ import attr
 from torch import nn, Tensor
 
 from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.local_attention import LocalAttention
 from text_recognizer.networks.transformer.mlp import FeedForward
 from text_recognizer.networks.transformer.residual import Residual
 
@@ -24,18 +23,13 @@ class AttentionLayers(nn.Module):
     norm: Type[nn.Module] = attr.ib()
     ff: FeedForward = attr.ib()
     cross_attn: Optional[Attention] = attr.ib(default=None)
-    local_self_attn: Optional[LocalAttention] = attr.ib(default=None)
     pre_norm: bool = attr.ib(default=True)
-    local_depth: Optional[int] = attr.ib(default=None)
     has_pos_emb: bool = attr.ib(default=False)
     layer_types: Tuple[str, ...] = attr.ib(init=False)
     layers: nn.ModuleList = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        if self.local_self_attn is not None:
-            if self.local_depth is None:
-                ValueError("Local depth has to be specified")
         self.layer_types = self._get_layer_types() * self.depth
         self.layers = self._build_network()
 
@@ -45,14 +39,8 @@ class AttentionLayers(nn.Module):
             return "a", "c", "f"
         return "a", "f"
 
-    def _self_attn_block(self, i: int) -> Type[nn.Module]:
-        if self.local_depth is not None and i < self.local_depth:
-            return deepcopy(self.local_self_attn)
-        return deepcopy(self.self_attn)
-
     def _delete(self) -> None:
         del self.self_attn
-        del self.local_self_attn
         del self.ff
         del self.norm
         del self.cross_attn
@@ -60,11 +48,9 @@ class AttentionLayers(nn.Module):
     def _build_network(self) -> nn.ModuleList:
         """Configures transformer network."""
         layers = nn.ModuleList([])
-        self_attn_depth = 0
         for layer_type in self.layer_types:
             if layer_type == "a":
-                layer = self._self_attn_block(self_attn_depth)
-                self_attn_depth += 1
+                layer = deepcopy(self.self_attn)
             elif layer_type == "c":
                 layer = deepcopy(self.cross_attn)
             elif layer_type == "f":
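For context, the net effect of this change on layer construction: every self-attention slot now receives an identical deepcopy of the single self_attn prototype, instead of routing the first local_depth slots through a LocalAttention module. A minimal sketch of the resulting logic, using hypothetical stub classes in place of the repo's Attention and FeedForward modules (which are not shown in this diff):

    # Sketch only (assumed stand-ins, not the repo's modules): shows how the
    # simplified _build_network expands the layer-type pattern and deep-copies
    # one prototype per slot, with no local-attention special case.
    from copy import deepcopy

    from torch import nn


    class SelfAttnStub(nn.Module):
        """Hypothetical placeholder for Attention."""


    class FeedForwardStub(nn.Module):
        """Hypothetical placeholder for FeedForward."""


    def build_layers(depth: int, cross_attend: bool) -> nn.ModuleList:
        # ("a", "c", "f") per block with cross-attention, else ("a", "f"),
        # repeated `depth` times -- mirrors _get_layer_types() * self.depth.
        layer_types = (("a", "c", "f") if cross_attend else ("a", "f")) * depth
        self_attn, cross_attn, ff = SelfAttnStub(), SelfAttnStub(), FeedForwardStub()

        layers = nn.ModuleList([])
        for layer_type in layer_types:
            if layer_type == "a":
                layers.append(deepcopy(self_attn))  # every "a" slot is now identical
            elif layer_type == "c":
                layers.append(deepcopy(cross_attn))
            elif layer_type == "f":
                layers.append(deepcopy(ff))
        return layers


    print(len(build_layers(depth=2, cross_attend=True)))  # 6: a, c, f, a, c, f

Deep-copying from a prototype keeps construction declarative (one configured module per role) while still giving each layer its own independent parameters.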