diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-12-05 20:23:26 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-12-05 20:23:26 +0100 |
commit | 8efa71d428b4cfa5586431853fc6e2914ba0b3ee (patch) | |
tree | 3be78a49558f9b84c74cd6f45e13daf870fe3b29 | |
parent | f0e006105b68a6e86a8c50f1a365fed0f65da460 (diff) |
Add base network
-rw-r--r-- | text_recognizer/networks/base.py | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py new file mode 100644 index 0000000..f6f1831 --- /dev/null +++ b/text_recognizer/networks/base.py @@ -0,0 +1,103 @@ +"""Base network module.""" +import math +from typing import Optional, Tuple, Type + +from loguru import logger as log +from torch import nn, Tensor + +from text_recognizer.networks.transformer.layers import Decoder + + +class BaseTransformer(nn.Module): + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + num_classes: int, + pad_index: Tensor, + encoder: Type[nn.Module], + decoder: Decoder, + token_pos_embedding: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.input_dims = input_dims + self.hidden_dim = hidden_dim + self.num_classes = num_classes + self.pad_index = pad_index + self.encoder = encoder + self.decoder = decoder + + # Token embedding. + self.token_embedding = nn.Embedding( + num_embeddings=self.num_classes, embedding_dim=self.hidden_dim + ) + + # Positional encoding for decoder tokens. + if not self.decoder.has_pos_emb: + self.token_pos_embedding = token_pos_embedding + else: + self.token_pos_embedding = None + log.debug("Decoder already have a positional embedding.") + + self.norm = nn.LayerNorm(self.hidden_dim) + + # Output layer + self.to_logits = nn.Linear( + in_features=self.hidden_dim, out_features=self.num_classes + ) + + def encode(self, x: Tensor) -> Tensor: + """Encodes images with encoder.""" + return self.encoder(x) + + def decode(self, src: Tensor, trg: Tensor) -> Tensor: + """Decodes latent images embedding into word pieces. + + Args: + src (Tensor): Latent images embedding. + trg (Tensor): Word embeddings. + + Shapes: + - z: :math: `(B, Sx, E)` + - context: :math: `(B, Sy)` + - out: :math: `(B, Sy, T)` + + where Sy is the length of the output and T is the number of tokens. + + Returns: + Tensor: Sequence of word piece embeddings. + """ + trg = trg.long() + trg_mask = trg != self.pad_index + trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim) + trg = ( + self.token_pos_embedding(trg) + if self.token_pos_embedding is not None + else trg + ) + out = self.decoder(x=trg, context=src, input_mask=trg_mask) + out = self.norm(out) + logits = self.to_logits(out) # [B, Sy, T] + logits = logits.permute(0, 2, 1) # [B, T, Sy] + return logits + + def forward(self, x: Tensor, context: Tensor) -> Tensor: + """Encodes images into word piece logtis. + + Args: + x (Tensor): Input image(s). + context (Tensor): Target word embeddings. + + Shapes: + - x: :math: `(B, C, H, W)` + - context: :math: `(B, Sy, T)` + + where B is the batch size, C is the number of input channels, H is + the image height and W is the image width. + + Returns: + Tensor: Sequence of logits. + """ + z = self.encode(x) + logits = self.decode(z, context) + return logits |