Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r--  text_recognizer/networks/transformer/layers.py  91
1 file changed, 48 insertions(+), 43 deletions(-)
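
This change rewrites AttentionLayers as an attrs-based class: the attention, norm, and feed-forward factories are no longer passed in as class types but as dotted-path strings that __attrs_post_init__ resolves through load_partial_fn. The helper itself lives in text_recognizer/networks/util.py and is not shown in this diff; the sketch below is only an assumption of how such a resolver could look, mapping a dotted import path plus keyword arguments to a functools.partial.

# Hypothetical sketch of load_partial_fn; the real implementation in
# text_recognizer.networks.util is not part of this diff.
from functools import partial
from importlib import import_module
from typing import Any


def load_partial_fn(fn: str, **kwargs: Any) -> partial:
    """Resolve a dotted path such as 'torch.nn.GELU' and bind keyword arguments."""
    module_name, attr_name = fn.rsplit(".", 1)
    return partial(getattr(import_module(module_name), attr_name), **kwargs)
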
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 4daa265..9b2f236 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,67 +1,74 @@
"""Transformer attention layer."""
from functools import partial
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple
+import attr
from torch import nn, Tensor
-from .attention import Attention
-from .mlp import FeedForward
-from .residual import Residual
-from .positional_encodings.rotary_embedding import RotaryEmbedding
+from text_recognizer.networks.transformer.residual import Residual
+from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import (
+ RotaryEmbedding,
+)
+from text_recognizer.networks.util import load_partial_fn
+@attr.s
class AttentionLayers(nn.Module):
- def __init__(
- self,
- dim: int,
- depth: int,
- num_heads: int,
- ff_kwargs: Dict,
- attn_kwargs: Dict,
- attn_fn: Type[nn.Module] = Attention,
- norm_fn: Type[nn.Module] = nn.LayerNorm,
- ff_fn: Type[nn.Module] = FeedForward,
- rotary_emb: Optional[Type[nn.Module]] = None,
- rotary_emb_dim: Optional[int] = None,
- causal: bool = False,
- cross_attend: bool = False,
- pre_norm: bool = True,
- ) -> None:
+    """Standard transformer layer."""
+
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.dim = dim
- attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
- norm_fn = partial(norm_fn, dim)
- ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
- self.layer_types = self._get_layer_types(cross_attend) * depth
- self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn)
- rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
- self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
- self.pre_norm = pre_norm
- self.has_pos_emb = True if self.rotary_emb is not None else False
- @staticmethod
- def _get_layer_types(cross_attend: bool) -> Tuple:
+ dim: int = attr.ib()
+ depth: int = attr.ib()
+ num_heads: int = attr.ib()
+ attn_fn: str = attr.ib()
+ attn_kwargs: Dict = attr.ib()
+ norm_fn: str = attr.ib()
+ ff_fn: str = attr.ib()
+ ff_kwargs: Dict = attr.ib()
+ causal: bool = attr.ib(default=False)
+ cross_attend: bool = attr.ib(default=False)
+ pre_norm: bool = attr.ib(default=True)
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
+ has_pos_emb: bool = attr.ib(init=False)
+ layer_types: Tuple[str, ...] = attr.ib(init=False)
+ layers: nn.ModuleList = attr.ib(init=False)
+ attn: partial = attr.ib(init=False)
+ norm: partial = attr.ib(init=False)
+ ff: partial = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.has_pos_emb = True if self.rotary_emb is not None else False
+ self.layer_types = self._get_layer_types() * self.depth
+ attn = load_partial_fn(
+ self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+ )
+ norm = load_partial_fn(self.norm_fn, dim=self.dim)
+ ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
+ self.layers = self._build_network(attn, norm, ff)
+
+ def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
- if cross_attend:
+ if self.cross_attend:
return "a", "c", "f"
return "a", "f"
def _build_network(
- self, causal: bool, attn_fn: partial, norm_fn: partial, ff_fn: partial,
+ self, attn: partial, norm: partial, ff: partial,
) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
for layer_type in self.layer_types:
if layer_type == "a":
- layer = attn_fn(causal=causal)
+ layer = attn(causal=self.causal)
elif layer_type == "c":
- layer = attn_fn()
+ layer = attn()
elif layer_type == "f":
- layer = ff_fn()
-
+ layer = ff()
residual_fn = Residual()
-
- layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+ layers.append(nn.ModuleList([norm(), layer, residual_fn]))
return layers
def forward(
@@ -72,12 +79,10 @@ class AttentionLayers(nn.Module):
context_mask: Optional[Tensor] = None,
) -> Tensor:
rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
-
for i, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = i == len(self.layers) - 1
-
residual = x
if self.pre_norm:
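
For reference, a hypothetical instantiation of the refactored class. The dotted paths for attn_fn and ff_fn are inferred from the relative imports this diff removes; the norm_fn path is a placeholder (load_partial_fn binds dim as a keyword, so the resolved callable must accept it), and the empty kwargs dicts assume the target callables provide defaults for their remaining arguments.

from text_recognizer.networks.transformer.layers import AttentionLayers

decoder = AttentionLayers(
    dim=256,
    depth=2,
    num_heads=8,
    attn_fn="text_recognizer.networks.transformer.attention.Attention",
    attn_kwargs={},  # assumes Attention defaults for anything beyond dim and num_heads
    norm_fn="text_recognizer.networks.transformer.norm.LayerNorm",  # hypothetical path
    ff_fn="text_recognizer.networks.transformer.mlp.FeedForward",
    ff_kwargs={},  # assumes FeedForward defaults for anything beyond dim
    causal=True,
    cross_attend=False,
)

Because __attrs_pre_init__ calls super().__init__() before attrs assigns any fields, nn.Module's bookkeeping (_parameters, _modules, etc.) is in place by the time the ModuleList is attached in __attrs_post_init__. Note also that rotary_emb is declared with init=False and a default of None, so rotary position embeddings stay disabled unless the attribute is assigned after construction.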