diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
| -rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 62 | 
1 files changed, 62 insertions, 0 deletions
"""Transformer wrapper."""
from typing import Any, Optional, Type

from torch import nn, Tensor

from .layers import AttentionLayers
from text_recognizer.networks.transformer.positional_encodings import (
    AbsolutePositionalEmbedding,
)


class Transformer(nn.Module):
    """Wrap a stack of attention layers with token embedding and a logit head.

    The module embeds integer token indices, optionally adds an absolute
    positional embedding, applies dropout, projects to the attention
    layers' width when the embedding width differs, runs the attention
    layers, and finally maps the result to per-token logits.
    """

    def __init__(
        self,
        num_tokens: int,
        max_seq_len: int,
        attn_layers: AttentionLayers,
        emb_dim: Optional[int] = None,
        emb_dropout: float = 0.0,
        use_pos_emb: bool = True,
    ) -> None:
        """Build the embedding, projection, and output layers.

        Args:
            num_tokens: Number of rows in the token embedding table and
                number of output features of the logit layer.
            max_seq_len: Stored on the instance and passed to
                ``AbsolutePositionalEmbedding`` when that is created.
            attn_layers: An already constructed attention stack. It must
                expose ``dim`` and ``has_pos_emb`` and be callable as
                ``attn_layers(x, mask=mask, **kwargs)``.
            emb_dim: Width of the token embedding. Falls back to
                ``attn_layers.dim`` when ``None``.
            emb_dropout: Dropout probability applied to the embeddings.
            use_pos_emb: If true, and ``attn_layers.has_pos_emb`` is
                false, an absolute positional embedding is added to the
                token embeddings.
        """
        super().__init__()
        dim = attn_layers.dim
        self.attn_layers = attn_layers
        emb_dim = emb_dim if emb_dim is not None else dim
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.emb_dropout = nn.Dropout(emb_dropout)
        # Skip the absolute embedding when the attention layers report
        # that they already carry their own positional information.
        self.pos_emb = (
            AbsolutePositionalEmbedding(emb_dim, max_seq_len)
            if (use_pos_emb and not self.attn_layers.has_pos_emb)
            else None
        )

        # Only project when the embedding width differs from the model width.
        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        # NOTE(review): ``self.norm`` is constructed but never applied in
        # ``forward``. Confirm whether ``attn_layers`` already normalizes
        # its output or whether a final ``self.norm(x)`` is missing.
        self.norm = nn.LayerNorm(dim)

        self._init_weights()

        self.logits = nn.Linear(dim, num_tokens)

    def _init_weights(self) -> None:
        """Initialize the token embedding from a normal with std 0.02."""
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        return_embeddings: bool = False,
        **kwargs: Any
    ) -> Tensor:
        """Embed token indices, run the attention layers, and emit logits.

        Args:
            x: 2-D tensor of token indices (presumably ``(batch, seq_len)``
                -- TODO confirm against callers).
            mask: Optional mask forwarded unchanged to ``attn_layers``.
            return_embeddings: If true, return the attention layers'
                output instead of passing it through the logit layer.
            **kwargs: Extra keyword arguments forwarded to ``attn_layers``.

        Returns:
            The logits, or the attention layers' output when
            ``return_embeddings`` is true.

        Raises:
            ValueError: If ``x`` is not 2-dimensional.
        """
        # The previous implementation enforced this implicitly by unpacking
        # ``x.shape`` into exactly two names; keep the check, but explicit.
        if x.dim() != 2:
            raise ValueError(
                f"Expected a 2-dimensional tensor of token indices, "
                f"got shape {tuple(x.shape)}."
            )

        x = self.token_emb(x)
        if self.pos_emb is not None:
            # Out-of-place add so the embedding output is not mutated in place.
            x = x + self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)
        x = self.attn_layers(x, mask=mask, **kwargs)
        return x if return_embeddings else self.logits(x)