diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-23 14:55:31 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-23 14:55:31 +0200 |
commit | a1d795bf02d14befc62cf600fb48842958148eba (patch) | |
tree | 21465c20262b15654985368731e8a289562e8df7 /text_recognizer/networks/cnn_tranformer.py | |
parent | d20802e1f412045f7afa4bd8ac50be3488945e90 (diff) |
Complete cnn-transformer network, not tested
Diffstat (limited to 'text_recognizer/networks/cnn_tranformer.py')
-rw-r--r-- | text_recognizer/networks/cnn_tranformer.py | 81 |
1 files changed, 64 insertions, 17 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py index 5c13e9a..e030cb8 100644 --- a/text_recognizer/networks/cnn_tranformer.py +++ b/text_recognizer/networks/cnn_tranformer.py @@ -3,6 +3,7 @@ import math from typing import Tuple, Type import attr +import torch from torch import nn, Tensor from text_recognizer.data.mappings import AbstractMapping @@ -18,13 +19,19 @@ class CnnTransformer(nn.Module): def __attrs_pre_init__(self) -> None: super().__init__() - # Parameters, + # Parameters and placeholders, input_dims: Tuple[int, int, int] = attr.ib() hidden_dim: int = attr.ib() dropout_rate: float = attr.ib() max_output_len: int = attr.ib() num_classes: int = attr.ib() padding_idx: int = attr.ib() + start_token: str = attr.ib() + start_index: int = attr.ib(init=False, default=None) + end_token: str = attr.ib() + end_index: int = attr.ib(init=False, default=None) + pad_token: str = attr.ib() + pad_index: int = attr.ib(init=False, default=None) # Modules. encoder: Type[nn.Module] = attr.ib() @@ -38,6 +45,9 @@ class CnnTransformer(nn.Module): def __attrs_post_init__(self) -> None: """Post init configuration.""" + self.start_index = int(self.mapping.get_index(self.start_token)) + self.end_index = int(self.mapping.get_index(self.end_token)) + self.pad_index = int(self.mapping.get_index(self.pad_token)) # Latent projector for down sampling number of filters and 2d # positional encoding. self.latent_encoder = nn.Sequential( @@ -99,20 +109,20 @@ class CnnTransformer(nn.Module): z = self.encoder(x) z = self.latent_encoder(z) - # Permute tensor from [B, E, Ho * Wo] to [Sx, B, E] - z = z.permute(2, 0, 1) + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) return z - def decode(self, z: Tensor, trg: Tensor) -> Tensor: + def decode(self, z: Tensor, context: Tensor) -> Tensor: """Decodes latent images embedding into word pieces. Args: z (Tensor): Latent images embedding. - trg (Tensor): Word embeddings. + context (Tensor): Word embeddings. Shapes: - z: :math: `(B, Sx, E)` - - trg: :math: `(B, Sy)` + - context: :math: `(B, Sy)` - out: :math: `(B, Sy, T)` where Sy is the length of the output and T is the number of tokens. @@ -120,32 +130,69 @@ class CnnTransformer(nn.Module): Returns: Tensor: Sequence of word piece embeddings. """ - trg_mask = trg != self.padding_idx - trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim) - trg = self.token_pos_encoder(trg) - out = self.decoder(x=trg, context=z, mask=trg_mask) + context_mask = context != self.padding_idx + context = self.token_embedding(context) * math.sqrt(self.hidden_dim) + context = self.token_pos_encoder(context) + out = self.decoder(x=context, context=z, mask=context_mask) logits = self.head(out) return logits - def forward(self, x: Tensor, trg: Tensor) -> Tensor: + def forward(self, x: Tensor, context: Tensor) -> Tensor: """Encodes images into word piece logtis. Args: x (Tensor): Input image(s). - trg (Tensor): Target word embeddings. + context (Tensor): Target word embeddings. Shapes: - x: :math: `(B, C, H, W)` - - trg: :math: `(B, Sy, T)` + - context: :math: `(B, Sy, T)` where B is the batch size, C is the number of input channels, H is the image height and W is the image width. + + Returns: + Tensor: Sequence of logits. """ z = self.encode(x) - logits = self.decode(z, trg) + logits = self.decode(z, context) return logits def predict(self, x: Tensor) -> Tensor: - """Predicts text in image.""" - # TODO: continue here!!!!!!!!! - pass + """Predicts text in image. + + Args: + x (Tensor): Image(s) to extract text from. + + Shapes: + - x: :math: `(B, H, W)` + - output: :math: `(B, S)` + + Returns: + Tensor: A tensor of token indices of the predictions from the model. + """ + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + z = self.encode(x) + + # Create a placeholder matrix for storing outputs from the network + output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + output[:, 0] = self.start_index + + for i in range(1, self.max_output_len): + context = output[:, :i] # (bsz, i) + logits = self.decode(z, context) # (i, bsz, c) + tokens = torch.argmax(logits, dim=-1) # (i, bsz) + output[:, i : i + 1] = tokens[-1:] + + # Early stopping of prediction loop if token is end or padding token. + if (output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = (output[:, i -1] == self.end_index | output[:, i - 1] == self.pad_index) + output[idx, i] = self.pad_index + + return output |