diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 27 |
1 files changed, 11 insertions, 16 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index ce443e5..70a0ac7 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,5 +1,4 @@ """Transformer attention layer.""" -from functools import partial from typing import Any, Dict, Optional, Tuple import attr @@ -27,25 +26,17 @@ class AttentionLayers(nn.Module): norm_fn: str = attr.ib() ff_fn: str = attr.ib() ff_kwargs: Dict = attr.ib() + rotary_emb: Optional[RotaryEmbedding] = attr.ib() causal: bool = attr.ib(default=False) cross_attend: bool = attr.ib(default=False) pre_norm: bool = attr.ib(default=True) - rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None) layer_types: Tuple[str, ...] = attr.ib(init=False) layers: nn.ModuleList = attr.ib(init=False) - attn: partial = attr.ib(init=False) - norm: partial = attr.ib(init=False) - ff: partial = attr.ib(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" self.layer_types = self._get_layer_types() * self.depth - attn = load_partial_fn( - self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs - ) - norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) - ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) - self.layers = self._build_network(attn, norm, ff) + self.layers = self._build_network() def _get_layer_types(self) -> Tuple: """Get layer specification.""" @@ -53,10 +44,13 @@ class AttentionLayers(nn.Module): return "a", "c", "f" return "a", "f" - def _build_network( - self, attn: partial, norm: partial, ff: partial, - ) -> nn.ModuleList: + def _build_network(self) -> nn.ModuleList: """Configures transformer network.""" + attn = load_partial_fn( + self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs + ) + norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) + ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) layers = nn.ModuleList([]) for layer_type in self.layer_types: if layer_type == "a": @@ -106,6 +100,7 @@ class Encoder(AttentionLayers): causal: bool = attr.ib(default=False, init=False) -@attr.s(auto_attribs=True, eq=False) class Decoder(AttentionLayers): - causal: bool = attr.ib(default=True, init=False) + def __init__(self, **kwargs: Any) -> None: + assert "causal" not in kwargs, "Cannot set causality on decoder" + super().__init__(causal=True, **kwargs) |