diff options
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 91 |
1 files changed, 48 insertions, 43 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 4daa265..9b2f236 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,67 +1,74 @@ """Transformer attention layer.""" from functools import partial -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Dict, Optional, Tuple +import attr from torch import nn, Tensor -from .attention import Attention -from .mlp import FeedForward -from .residual import Residual -from .positional_encodings.rotary_embedding import RotaryEmbedding +from text_recognizer.networks.transformer.residual import Residual +from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import ( + RotaryEmbedding, +) +from text_recognizer.networks.util import load_partial_fn +@attr.s class AttentionLayers(nn.Module): - def __init__( - self, - dim: int, - depth: int, - num_heads: int, - ff_kwargs: Dict, - attn_kwargs: Dict, - attn_fn: Type[nn.Module] = Attention, - norm_fn: Type[nn.Module] = nn.LayerNorm, - ff_fn: Type[nn.Module] = FeedForward, - rotary_emb: Optional[Type[nn.Module]] = None, - rotary_emb_dim: Optional[int] = None, - causal: bool = False, - cross_attend: bool = False, - pre_norm: bool = True, - ) -> None: + """Standard transfomer layer.""" + + def __attrs_pre_init__(self) -> None: super().__init__() - self.dim = dim - attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) - norm_fn = partial(norm_fn, dim) - ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) - self.layer_types = self._get_layer_types(cross_attend) * depth - self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn) - rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None - self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None - self.pre_norm = pre_norm - self.has_pos_emb = True if self.rotary_emb is not None else False - @staticmethod - def _get_layer_types(cross_attend: bool) -> Tuple: + dim: int = attr.ib() + depth: int = attr.ib() + num_heads: int = attr.ib() + attn_fn: str = attr.ib() + attn_kwargs: Dict = attr.ib() + norm_fn: str = attr.ib() + ff_fn: str = attr.ib() + ff_kwargs: Dict = attr.ib() + causal: bool = attr.ib(default=False) + cross_attend: bool = attr.ib(default=False) + pre_norm: bool = attr.ib(default=True) + rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False) + has_pos_emb: bool = attr.ib(init=False) + layer_types: Tuple[str, ...] = attr.ib(init=False) + layers: nn.ModuleList = attr.ib(init=False) + attn: partial = attr.ib(init=False) + norm: partial = attr.ib(init=False) + ff: partial = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.has_pos_emb = True if self.rotary_emb is not None else False + self.layer_types = self._get_layer_types() * self.depth + attn = load_partial_fn( + self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs + ) + norm = load_partial_fn(self.norm_fn, dim=self.dim) + ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) + self.layers = self._build_network(attn, norm, ff) + + def _get_layer_types(self) -> Tuple: """Get layer specification.""" - if cross_attend: + if self.cross_attend: return "a", "c", "f" return "a", "f" def _build_network( - self, causal: bool, attn_fn: partial, norm_fn: partial, ff_fn: partial, + self, attn: partial, norm: partial, ff: partial, ) -> nn.ModuleList: """Configures transformer network.""" layers = nn.ModuleList([]) for layer_type in self.layer_types: if layer_type == "a": - layer = attn_fn(causal=causal) + layer = attn(causal=self.causal) elif layer_type == "c": - layer = attn_fn() + layer = attn() elif layer_type == "f": - layer = ff_fn() - + layer = ff() residual_fn = Residual() - - layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + layers.append(nn.ModuleList([norm(), layer, residual_fn])) return layers def forward( @@ -72,12 +79,10 @@ class AttentionLayers(nn.Module): context_mask: Optional[Tensor] = None, ) -> Tensor: rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None - for i, (layer_type, (norm, block, residual_fn)) in enumerate( zip(self.layer_types, self.layers) ): is_last = i == len(self.layers) - 1 - residual = x if self.pre_norm: |