From 8c4a0c2603975cfc63f4e4019386e001387c42c9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 4 Oct 2022 22:08:38 +0200 Subject: Add greedy decoder --- text_recognizer/models/greedy_decoder.py | 51 +++++++++++++++++++++++++ text_recognizer/models/transformer.py | 65 ++++++-------------------------- 2 files changed, 63 insertions(+), 53 deletions(-) create mode 100644 text_recognizer/models/greedy_decoder.py (limited to 'text_recognizer') diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py new file mode 100644 index 0000000..9d2f192 --- /dev/null +++ b/text_recognizer/models/greedy_decoder.py @@ -0,0 +1,51 @@ +"""Greedy decoder.""" +from typing import Type +from text_recognizer.data.tokenizer import Tokenizer +import torch +from torch import nn, Tensor + + +class GreedyDecoder: + def __init__( + self, + network: Type[nn.Module], + tokenizer: Tokenizer, + max_output_len: int = 682, + ) -> None: + self.network = network + self.start_index = tokenizer.start_index + self.end_index = tokenizer.end_index + self.pad_index = tokenizer.pad_index + self.max_output_len = max_output_len + + def __call__(self, x: Tensor) -> Tensor: + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + img_features = self.network.encode(x) + + # Create a placeholder matrix for storing outputs from the network + indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + indecies[:, 0] = self.start_index + + for Sy in range(1, self.max_output_len): + tokens = indecies[:, :Sy] # (B, Sy) + logits = self.network.decode(tokens, img_features) # (B, C, Sy) + indecies_ = torch.argmax(logits, dim=1) # (B, Sy) + indecies[:, Sy : Sy + 1] = indecies_[:, -1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + (indecies[:, Sy - 1] == self.end_index) + | (indecies[:, Sy - 1] == self.pad_index) + ).all(): + break + + # Set all tokens after end token to pad token. + for Sy in range(1, self.max_output_len): + idx = (indecies[:, Sy - 1] == self.end_index) | ( + indecies[:, Sy - 1] == self.pad_index + ) + indecies[idx, Sy] = self.pad_index + + return indecies diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 6048901..bbfaac0 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,6 +1,6 @@ """Lightning model for base Transformers.""" -from collections.abc import Sequence -from typing import Optional, Tuple, Type +from typing import Callable, Optional, Sequence, Tuple, Type +from text_recognizer.models.greedy_decoder import GreedyDecoder import torch from omegaconf import DictConfig @@ -20,6 +20,7 @@ class LitTransformer(LitBase): loss_fn: Type[nn.Module], optimizer_config: DictConfig, tokenizer: Tokenizer, + decoder: Callable = GreedyDecoder, lr_scheduler_config: Optional[DictConfig] = None, max_output_len: int = 682, ) -> None: @@ -35,9 +36,10 @@ class LitTransformer(LitBase): self.test_cer = CharErrorRate() self.val_wer = WordErrorRate() self.test_wer = WordErrorRate() + self.decoder = decoder def forward(self, data: Tensor) -> Tensor: - """Forward pass with the transformer network.""" + """Autoregressive forward pass.""" return self.predict(data) def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor: @@ -59,7 +61,7 @@ class LitTransformer(LitBase): logits = self.teacher_forward(data, targets[:, :-1]) loss = self.loss_fn(logits, targets[:, 1:]) preds = self.predict(data) - pred_text, target_text = self._get_text(preds), self._get_text(targets) + pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) self.val_acc(preds, targets) self.val_cer(pred_text, target_text) @@ -76,7 +78,7 @@ class LitTransformer(LitBase): logits = self.teacher_forward(data, targets[:, :-1]) loss = self.loss_fn(logits, targets[:, 1:]) preds = self(data) - pred_text, target_text = self._get_text(preds), self._get_text(targets) + pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) self.test_acc(preds, targets) self.test_cer(pred_text, target_text) @@ -86,55 +88,12 @@ class LitTransformer(LitBase): self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) - def _get_text( + def _to_tokens( self, - xs: Tensor, - ) -> Tuple[Sequence[str], Sequence[str]]: - return [self.tokenizer.decode(x) for x in xs] + indecies: Tensor, + ) -> Sequence[str]: + return [self.tokenizer.decode(i) for i in indecies] @torch.no_grad() def predict(self, x: Tensor) -> Tensor: - """Predicts text in image. - - Args: - x (Tensor): Image(s) to extract text from. - - Shapes: - - x: :math: `(B, H, W)` - - output: :math: `(B, S)` - - Returns: - Tensor: A tensor of token indices of the predictions from the model. - """ - start_index = self.tokenizer.start_index - end_index = self.tokenizer.end_index - pad_index = self.tokenizer.pad_index - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - img_features = self.network.encode(x) - - # Create a placeholder matrix for storing outputs from the network - indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - indecies[:, 0] = start_index - - for Sy in range(1, self.max_output_len): - tokens = indecies[:, :Sy] # (B, Sy) - logits = self.network.decode(tokens, img_features) # (B, C, Sy) - indecies_ = torch.argmax(logits, dim=1) # (B, Sy) - indecies[:, Sy : Sy + 1] = indecies_[:, -1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - (indecies[:, Sy - 1] == end_index) | (indecies[:, Sy - 1] == pad_index) - ).all(): - break - - # Set all tokens after end token to pad token. - for Sy in range(1, self.max_output_len): - idx = (indecies[:, Sy - 1] == end_index) | ( - indecies[:, Sy - 1] == pad_index - ) - indecies[idx, Sy] = pad_index - - return indecies + return self.decoder(x) -- cgit v1.2.3-70-g09d2