summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r--text_recognizer/networks/transformer/layers.py29
1 files changed, 21 insertions, 8 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index a2fdb1a..4063425 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,6 +1,6 @@
"""Generates the attention layer architecture."""
from functools import partial
-from typing import Dict, Optional, Type
+from typing import Any, Dict, Optional, Type
from click.types import Tuple
@@ -36,12 +36,11 @@ class AttentionLayers(nn.Module):
norm_fn = partial(norm_fn, dim=dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
self.layer_types = self._get_layer_types(cross_attend) * depth
- self.layers = self._build_network(
- causal, attn_fn, norm_fn, ff_fn, residual_fn
- )
+ self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn)
rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
self.pre_norm = pre_norm
+ self.has_pos_emb = True if self.rotary_emb is not None else False
@staticmethod
def _get_layer_types(cross_attend: bool) -> Tuple:
@@ -70,7 +69,7 @@ class AttentionLayers(nn.Module):
residual_fn = residual_fn()
- layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+ layers.append(nn.modulelist([norm_fn(), layer, residual_fn]))
return layers
def forward(
@@ -82,10 +81,12 @@ class AttentionLayers(nn.Module):
) -> Tensor:
rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
- for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ for i, (layer_type, (norm, block, residual_fn)) in enumerate(
+ zip(self.layer_types, self.layers)
+ ):
is_last = i == len(self.layers) - 1
-
- residual = x
+
+ residual = x
if self.pre_norm:
x = norm(x)
@@ -103,3 +104,15 @@ class AttentionLayers(nn.Module):
x = norm(x)
return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs: Any) -> None:
+ assert "causal" not in kwargs, "Cannot set causality on encoder"
+ super().__init__(causal=False, **kwargs)
+
+
+class Decoder(AttentionLayers):
+ def __init__(self, **kwargs: Any) -> None:
+ assert "causal" not in kwargs, "Cannot set causality on decoder"
+ super().__init__(causal=True, **kwargs)