diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py new file mode 100644 index 0000000..31088b4 --- /dev/null +++ b/text_recognizer/networks/transformer/transformer.py @@ -0,0 +1,62 @@ +"""Transformer wrapper.""" +from typing import Any, Optional, Type + +from torch import nn, Tensor + +from .layers import AttentionLayers +from text_recognizer.networks.transformer.positional_encodings import ( + AbsolutePositionalEmbedding, +) + + +class Transformer(nn.Module): + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers: Type[AttentionLayers], + emb_dim: Optional[int] = None, + emb_dropout: float = 0.0, + use_pos_emb: bool = True, + ) -> None: + super().__init__() + dim = attn_layers.dim + self.attn_layers = attn_layers + emb_dim = emb_dim if emb_dim is not None else dim + self.max_seq_len = max_seq_len + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.emb_dropout = nn.Dropout(emb_dropout) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not self.attn_layers.has_pos_emb) + else None + ) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.norm = nn.LayerNorm(dim) + + self._init_weights() + + self.logits = nn.Linear(dim, num_tokens) + + def _init_weights(self) -> None: + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + return_embeddings: bool = False, + **kwargs: Any + ) -> Tensor: + b, n, device = *x.shape, x.device + x = self.token_emb(x) + if self.pos_emb is not None: + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + x = self.attn_layers(x, mask=mask, **kwargs) + out = self.logits(x) if not return_embeddings else x + return out |