Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 80 |
1 file changed, 70 insertions, 10 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index f5cb491..7a9d566 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,11 +1,11 @@
 """PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
+from typing import Sequence, Tuple, Type
 
 import attr
-import hydra
-from omegaconf import DictConfig
-from torch import nn, Tensor
+import torch
+from torch import Tensor
 
+from text_recognizer.data.mappings import AbstractMapping
 from text_recognizer.models.metrics import CharacterErrorRate
 from text_recognizer.models.base import BaseLitModel
 
@@ -13,18 +13,31 @@ from text_recognizer.models.base import BaseLitModel
 @attr.s(auto_attribs=True)
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
 
+    mapping: Type[AbstractMapping] = attr.ib()
+    start_token: str = attr.ib()
+    end_token: str = attr.ib()
+    pad_token: str = attr.ib()
-    ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+    start_index: Tensor = attr.ib(init=False)
+    end_index: Tensor = attr.ib(init=False)
+    pad_index: Tensor = attr.ib(init=False)
+
+    ignore_indices: Sequence[Tensor] = attr.ib(init=False)
     val_cer: CharacterErrorRate = attr.ib(init=False)
     test_cer: CharacterErrorRate = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
-        self.val_cer = CharacterErrorRate(self.ignore_tokens)
-        self.test_cer = CharacterErrorRate(self.ignore_tokens)
+        """Post init configuration."""
+        self.start_index = self.mapping.get_index(self.start_token)
+        self.end_index = self.mapping.get_index(self.end_token)
+        self.pad_index = self.mapping.get_index(self.pad_token)
+        self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+        self.val_cer = CharacterErrorRate(self.ignore_indices)
+        self.test_cer = CharacterErrorRate(self.ignore_indices)
 
     def forward(self, data: Tensor) -> Tensor:
         """Forward pass with the transformer network."""
-        return self.network.predict(data)
+        return self.predict(data)
 
     def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
         """Training step."""
@@ -38,17 +51,64 @@ class TransformerLitModel(BaseLitModel):
         """Validation step."""
         data, targets = batch
 
+        # Compute the loss.
         logits = self.network(data, targets[:-1])
         loss = self.loss_fn(logits, targets[1:])
         self.log("val/loss", loss, prog_bar=True)
 
-        pred = self.network.predict(data)
+        # Get the token prediction.
+        pred = self(data)
         self.val_cer(pred, targets)
         self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
         data, targets = batch
-        pred = self.network.predict(data)
+
+        # Compute the text prediction.
+        pred = self(data)
         self.test_cer(pred, targets)
         self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+
+    def predict(self, x: Tensor) -> Tensor:
+        """Predicts text in image.
+
+        Args:
+            x (Tensor): Image(s) to extract text from.
+
+        Shapes:
+            - x: :math:`(B, H, W)`
+            - output: :math:`(B, S)`
+
+        Returns:
+            Tensor: A tensor of token indices of the predictions from the model.
+        """
+        bsz = x.shape[0]
+
+        # Encode image(s) to latent vectors.
+        z = self.network.encode(x)
+
+        # Create a placeholder matrix for storing outputs from the network.
+        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+        output[:, 0] = self.start_index
+
+        for i in range(1, self.max_output_len):
+            context = output[:, :i]  # (bsz, i)
+            logits = self.network.decode(z, context)  # (i, bsz, c)
+            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
+            output[:, i] = tokens[-1]
+
+            # Early stopping of prediction loop if token is end or padding token.
+            if (
+                (output[:, i - 1] == self.end_index) | (output[:, i - 1] == self.pad_index)
+            ).all():
+                break
+
+        # Set all tokens after end token to pad token.
+        for i in range(1, self.max_output_len):
+            idx = (output[:, i - 1] == self.end_index) | (
+                output[:, i - 1] == self.pad_index
+            )
+            output[idx, i] = self.pad_index
+
+        return output
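For context, the new `predict` method is a standard greedy autoregressive decoding loop: encode the image once, then repeatedly decode, take the argmax token at the newest position, and stop early once every sequence in the batch has emitted an end or pad token. Below is a minimal standalone sketch of the same idea; the `encode`/`decode` callables are illustrative stand-ins for `network.encode`/`network.decode` (not the exact network API), and pre-filling the placeholder with the pad index is an assumption that replaces the patch's ones-initialization plus post-loop cleanup.

    import torch
    from torch import Tensor


    def greedy_decode(
        encode,
        decode,
        images: Tensor,
        max_output_len: int,
        start_index: int,
        end_index: int,
        pad_index: int,
    ) -> Tensor:
        """Greedily decode one token per step until end/pad or max length."""
        bsz = images.shape[0]
        z = encode(images)  # latent image representation, computed once

        # Placeholder filled with pad tokens, seeded with the start token.
        output = torch.full(
            (bsz, max_output_len), pad_index, dtype=torch.long, device=images.device
        )
        output[:, 0] = start_index

        for i in range(1, max_output_len):
            logits = decode(z, output[:, :i])  # (i, bsz, num_classes)
            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
            output[:, i] = tokens[-1]  # keep only the newest position

            # Stop once every sequence has produced an end or pad token.
            done = (output[:, i] == end_index) | (output[:, i] == pad_index)
            if done.all():
                break

        return output

Starting from a pad-filled buffer means positions after an early stop are already pad tokens, which is why this sketch needs no second back-fill loop.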