Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r--  text_recognizer/networks/transformer/layers.py  27
1 file changed, 11 insertions(+), 16 deletions(-)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index ce443e5..70a0ac7 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,5 +1,4 @@
 """Transformer attention layer."""
-from functools import partial
 from typing import Any, Dict, Optional, Tuple
 
 import attr
@@ -27,25 +26,17 @@ class AttentionLayers(nn.Module):
     norm_fn: str = attr.ib()
     ff_fn: str = attr.ib()
     ff_kwargs: Dict = attr.ib()
+    rotary_emb: Optional[RotaryEmbedding] = attr.ib()
     causal: bool = attr.ib(default=False)
     cross_attend: bool = attr.ib(default=False)
     pre_norm: bool = attr.ib(default=True)
-    rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
     layer_types: Tuple[str, ...] = attr.ib(init=False)
     layers: nn.ModuleList = attr.ib(init=False)
-    attn: partial = attr.ib(init=False)
-    norm: partial = attr.ib(init=False)
-    ff: partial = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
         self.layer_types = self._get_layer_types() * self.depth
-        attn = load_partial_fn(
-            self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
-        )
-        norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
-        ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
-        self.layers = self._build_network(attn, norm, ff)
+        self.layers = self._build_network()
 
     def _get_layer_types(self) -> Tuple:
         """Get layer specification."""
@@ -53,10 +44,13 @@ class AttentionLayers(nn.Module):
             return "a", "c", "f"
         return "a", "f"
 
-    def _build_network(
-        self, attn: partial, norm: partial, ff: partial,
-    ) -> nn.ModuleList:
+    def _build_network(self) -> nn.ModuleList:
         """Configures transformer network."""
+        attn = load_partial_fn(
+            self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+        )
+        norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
+        ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
         layers = nn.ModuleList([])
         for layer_type in self.layer_types:
             if layer_type == "a":
@@ -106,6 +100,7 @@ class Encoder(AttentionLayers):
     causal: bool = attr.ib(default=False, init=False)
 
 
-@attr.s(auto_attribs=True, eq=False)
 class Decoder(AttentionLayers):
-    causal: bool = attr.ib(default=True, init=False)
+    def __init__(self, **kwargs: Any) -> None:
+        assert "causal" not in kwargs, "Cannot set causality on decoder"
+        super().__init__(causal=True, **kwargs)
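For illustration, a minimal standalone sketch (hypothetical class names, not the project's) of the pattern the Decoder change uses: the attrs-decorated base class keeps its generated __init__, while the subclass drops the @attr.s decorator and its field override and instead defines a manual __init__ that pins causal=True and rejects the keyword. Re-decorating the subclass with @attr.s would regenerate __init__ and clobber the manual one, which is presumably why the decorator line is removed here.

from typing import Any

import attr


@attr.s(auto_attribs=True, eq=False)
class Layers:
    # Hypothetical stand-in for AttentionLayers: attrs generates __init__ from these fields.
    dim: int = attr.ib()
    causal: bool = attr.ib(default=False)


class CausalLayers(Layers):
    # Hypothetical stand-in for Decoder: no @attr.s here, so this __init__ is not overwritten.
    def __init__(self, **kwargs: Any) -> None:
        assert "causal" not in kwargs, "Cannot set causality here"
        super().__init__(causal=True, **kwargs)


decoder = CausalLayers(dim=64)
print(decoder.causal)                 # True
# CausalLayers(dim=64, causal=False)  # would raise AssertionError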