diff options
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py new file mode 100644 index 0000000..1c951ae --- /dev/null +++ b/text_recognizer/networks/transformer/layers.py @@ -0,0 +1,77 @@ +"""Generates the attention layer architecture.""" +from functools import partial +from typing import Dict, Optional, Type + +from click.types import Tuple + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .mlp import FeedForward +from .residual import Residual + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + ff_kwargs: Dict, + attn_kwargs: Dict, + attn_fn: Type[nn.Module] = Attention, + norm_fn: Type[nn.Module] = nn.LayerNorm, + ff_fn: Type[nn.Module] = FeedForward, + residual_fn: Type[nn.Module] = Residual, + causal: bool = False, + cross_attend: bool = False, + ) -> None: + super().__init__() + attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) + norm_fn = partial(norm_fn, dim=dim) + ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) + layer_types = self._get_layer_types(cross_attend) * depth + self.layers = self._build_network( + layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn + ) + + @staticmethod + def _get_layer_types(cross_attend: bool) -> Tuple: + """Get layer specification.""" + if cross_attend: + return "a", "c", "f" + return "a", "f" + + @staticmethod + def _build_network( + layer_types: Tuple, + causal: bool, + attn_fn: partial, + norm_fn: partial, + ff_fn: partial, + residual_fn: Type[nn.Module], + ) -> nn.ModuleList: + """Configures transformer layers.""" + layers = nn.ModuleList([]) + for layer_type in layer_types: + if layer_type == "a": + layer = attn_fn(causal=causal) + elif layer_type == "c": + layer = attn_fn() + elif layer_type == "f": + layer = ff_fn() + + residual_fn = residual_fn() + + layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + return layers + + def forward( + self, + x: Tensor, + context: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + ) -> Tensor: + pass |