diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 62 |
1 files changed, 0 insertions, 62 deletions
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py deleted file mode 100644 index 31088b4..0000000 --- a/text_recognizer/networks/transformer/transformer.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Transformer wrapper.""" -from typing import Any, Optional, Type - -from torch import nn, Tensor - -from .layers import AttentionLayers -from text_recognizer.networks.transformer.positional_encodings import ( - AbsolutePositionalEmbedding, -) - - -class Transformer(nn.Module): - def __init__( - self, - num_tokens: int, - max_seq_len: int, - attn_layers: Type[AttentionLayers], - emb_dim: Optional[int] = None, - emb_dropout: float = 0.0, - use_pos_emb: bool = True, - ) -> None: - super().__init__() - dim = attn_layers.dim - self.attn_layers = attn_layers - emb_dim = emb_dim if emb_dim is not None else dim - self.max_seq_len = max_seq_len - - self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.emb_dropout = nn.Dropout(emb_dropout) - self.pos_emb = ( - AbsolutePositionalEmbedding(emb_dim, max_seq_len) - if (use_pos_emb and not self.attn_layers.has_pos_emb) - else None - ) - - self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() - self.norm = nn.LayerNorm(dim) - - self._init_weights() - - self.logits = nn.Linear(dim, num_tokens) - - def _init_weights(self) -> None: - nn.init.normal_(self.token_emb.weight, std=0.02) - - def forward( - self, - x: Tensor, - mask: Optional[Tensor] = None, - return_embeddings: bool = False, - **kwargs: Any - ) -> Tensor: - b, n, device = *x.shape, x.device - x = self.token_emb(x) - if self.pos_emb is not None: - x += self.pos_emb(x) - x = self.emb_dropout(x) - - x = self.project_emb(x) - x = self.attn_layers(x, mask=mask, **kwargs) - out = self.logits(x) if not return_embeddings else x - return out |