summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 18:50:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 18:50:55 +0200
commita2a3133ed5da283888efbdb9924d0e3733c274c8 (patch)
treef6b49a227b08ff2e1a1c5809a576de6a2061ccf4 /text_recognizer/networks/transformer
parent548f52b35062e258622ea638ed1b132d6759a07a (diff)
tranformer layer done
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/layers.py42
1 files changed, 35 insertions, 7 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 1c951ae..a2fdb1a 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -10,6 +10,7 @@ from torch import nn, Tensor
from .attention import Attention
from .mlp import FeedForward
from .residual import Residual
+from .rotary_embedding import RotaryEmbedding
class AttentionLayers(nn.Module):
@@ -24,17 +25,23 @@ class AttentionLayers(nn.Module):
norm_fn: Type[nn.Module] = nn.LayerNorm,
ff_fn: Type[nn.Module] = FeedForward,
residual_fn: Type[nn.Module] = Residual,
+ rotary_emb: Optional[Type[nn.Module]] = None,
+ rotary_emb_dim: Optional[int] = None,
causal: bool = False,
cross_attend: bool = False,
+ pre_norm: bool = True,
) -> None:
super().__init__()
attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
norm_fn = partial(norm_fn, dim=dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
- layer_types = self._get_layer_types(cross_attend) * depth
+ self.layer_types = self._get_layer_types(cross_attend) * depth
self.layers = self._build_network(
- layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn
+ causal, attn_fn, norm_fn, ff_fn, residual_fn
)
+ rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
+ self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
+ self.pre_norm = pre_norm
@staticmethod
def _get_layer_types(cross_attend: bool) -> Tuple:
@@ -43,18 +50,17 @@ class AttentionLayers(nn.Module):
return "a", "c", "f"
return "a", "f"
- @staticmethod
def _build_network(
- layer_types: Tuple,
+ self,
causal: bool,
attn_fn: partial,
norm_fn: partial,
ff_fn: partial,
residual_fn: Type[nn.Module],
) -> nn.ModuleList:
- """Configures transformer layers."""
+ """Configures transformer network."""
layers = nn.ModuleList([])
- for layer_type in layer_types:
+ for layer_type in self.layer_types:
if layer_type == "a":
layer = attn_fn(causal=causal)
elif layer_type == "c":
@@ -74,4 +80,26 @@ class AttentionLayers(nn.Module):
mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
- pass
+ rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
+
+ for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = i == len(self.layers) - 1
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == "a":
+ out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb)
+ elif layer_type == "c":
+ out, _ = block(x, context=context, mask=mask, context_mask=context_mask)
+ elif layer_type == "f":
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ return x