 README.md                                |  2 +-
 text_recognizer/models/greedy_decoder.py | 51 ++++++++++++
 text_recognizer/models/transformer.py    | 65 +++----------
 training/conf/config.yaml                |  1 +
 training/run.py                          | 10 +++++++++-
 5 files changed, 74 insertions(+), 55 deletions(-)
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -73,7 +73,7 @@ Ideas of mine that did not work unfortunately:
 - [ ] Evaluation
 - [ ] Wandb artifact fetcher
 - [ ] fix linting
-- [ ] Modularize the decoder
+- [x] Modularize the decoder
 - [ ] Add kv cache
 - [x] Fix stems
 - [x] residual attn
diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py
new file mode 100644
index 0000000..9d2f192
--- /dev/null
+++ b/text_recognizer/models/greedy_decoder.py
@@ -0,0 +1,51 @@
+"""Greedy decoder."""
+from typing import Type
+from text_recognizer.data.tokenizer import Tokenizer
+import torch
+from torch import nn, Tensor
+
+
+class GreedyDecoder:
+    def __init__(
+        self,
+        network: Type[nn.Module],
+        tokenizer: Tokenizer,
+        max_output_len: int = 682,
+    ) -> None:
+        self.network = network
+        self.start_index = tokenizer.start_index
+        self.end_index = tokenizer.end_index
+        self.pad_index = tokenizer.pad_index
+        self.max_output_len = max_output_len
+
+    def __call__(self, x: Tensor) -> Tensor:
+        bsz = x.shape[0]
+
+        # Encode image(s) to latent vectors.
+        img_features = self.network.encode(x)
+
+        # Create a placeholder matrix for storing outputs from the network
+        indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+        indecies[:, 0] = self.start_index
+
+        for Sy in range(1, self.max_output_len):
+            tokens = indecies[:, :Sy]  # (B, Sy)
+            logits = self.network.decode(tokens, img_features)  # (B, C, Sy)
+            indecies_ = torch.argmax(logits, dim=1)  # (B, Sy)
+            indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
+
+            # Early stopping of prediction loop if token is end or padding token.
+            if (
+                (indecies[:, Sy - 1] == self.end_index)
+                | (indecies[:, Sy - 1] == self.pad_index)
+            ).all():
+                break
+
+        # Set all tokens after end token to pad token.
+        for Sy in range(1, self.max_output_len):
+            idx = (indecies[:, Sy - 1] == self.end_index) | (
+                indecies[:, Sy - 1] == self.pad_index
+            )
+            indecies[idx, Sy] = self.pad_index
+
+        return indecies
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6048901..bbfaac0 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,6 +1,6 @@
 """Lightning model for base Transformers."""
-from collections.abc import Sequence
-from typing import Optional, Tuple, Type
+from typing import Callable, Optional, Sequence, Tuple, Type
+from text_recognizer.models.greedy_decoder import GreedyDecoder
 
 import torch
 from omegaconf import DictConfig
@@ -20,6 +20,7 @@ class LitTransformer(LitBase):
         loss_fn: Type[nn.Module],
         optimizer_config: DictConfig,
         tokenizer: Tokenizer,
+        decoder: Callable = GreedyDecoder,
         lr_scheduler_config: Optional[DictConfig] = None,
         max_output_len: int = 682,
     ) -> None:
@@ -35,9 +36,10 @@ class LitTransformer(LitBase):
         self.test_cer = CharErrorRate()
         self.val_wer = WordErrorRate()
         self.test_wer = WordErrorRate()
+        self.decoder = decoder
 
     def forward(self, data: Tensor) -> Tensor:
-        """Forward pass with the transformer network."""
+        """Autoregressive forward pass."""
         return self.predict(data)
 
     def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
@@ -59,7 +61,7 @@ class LitTransformer(LitBase):
         logits = self.teacher_forward(data, targets[:, :-1])
         loss = self.loss_fn(logits, targets[:, 1:])
         preds = self.predict(data)
-        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
 
         self.val_acc(preds, targets)
         self.val_cer(pred_text, target_text)
@@ -76,7 +78,7 @@ class LitTransformer(LitBase):
         logits = self.teacher_forward(data, targets[:, :-1])
         loss = self.loss_fn(logits, targets[:, 1:])
         preds = self(data)
-        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
 
         self.test_acc(preds, targets)
         self.test_cer(pred_text, target_text)
@@ -86,55 +88,12 @@ class LitTransformer(LitBase):
         self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
 
-    def _get_text(
+    def _to_tokens(
         self,
-        xs: Tensor,
-    ) -> Tuple[Sequence[str], Sequence[str]]:
-        return [self.tokenizer.decode(x) for x in xs]
+        indecies: Tensor,
+    ) -> Sequence[str]:
+        return [self.tokenizer.decode(i) for i in indecies]
 
     @torch.no_grad()
     def predict(self, x: Tensor) -> Tensor:
-        """Predicts text in image.
-
-        Args:
-            x (Tensor): Image(s) to extract text from.
-
-        Shapes:
-            - x: :math: `(B, H, W)`
-            - output: :math: `(B, S)`
-
-        Returns:
-            Tensor: A tensor of token indices of the predictions from the model.
-        """
-        start_index = self.tokenizer.start_index
-        end_index = self.tokenizer.end_index
-        pad_index = self.tokenizer.pad_index
-        bsz = x.shape[0]
-
-        # Encode image(s) to latent vectors.
-        img_features = self.network.encode(x)
-
-        # Create a placeholder matrix for storing outputs from the network
-        indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
-        indecies[:, 0] = start_index
-
-        for Sy in range(1, self.max_output_len):
-            tokens = indecies[:, :Sy]  # (B, Sy)
-            logits = self.network.decode(tokens, img_features)  # (B, C, Sy)
-            indecies_ = torch.argmax(logits, dim=1)  # (B, Sy)
-            indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
-
-            # Early stopping of prediction loop if token is end or padding token.
-            if (
-                (indecies[:, Sy - 1] == end_index) | (indecies[:, Sy - 1] == pad_index)
-            ).all():
-                break
-
-        # Set all tokens after end token to pad token.
-        for Sy in range(1, self.max_output_len):
-            idx = (indecies[:, Sy - 1] == end_index) | (
-                indecies[:, Sy - 1] == pad_index
-            )
-            indecies[idx, Sy] = pad_index
-
-        return indecies
+        return self.decoder(x)
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index a95307f..e57a8a8 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -4,6 +4,7 @@ defaults:
   - _self_
   - callbacks: default
   - criterion: cross_entropy
+  - decoder: greedy
   - datamodule: iam_extended_paragraphs
   - hydra: default
   - logger: wandb
diff --git a/training/run.py b/training/run.py
index 429b1a2..288a1ef 100644
--- a/training/run.py
+++ b/training/run.py
@@ -1,5 +1,5 @@
 """Script to run experiments."""
-from typing import List, Optional, Type
+from typing import Callable, List, Optional, Type
 
 import hydra
 from loguru import logger as log
@@ -34,11 +34,19 @@ def run(config: DictConfig) -> Optional[float]:
     log.info(f"Instantiating criterion <{config.criterion._target_}>")
     loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
 
+    log.info(f"Instantiating decoder <{config.decoder._target_}>")
+    decoder: Type[Callable] = hydra.utils.instantiate(
+        config.decoder,
+        network=network,
+        tokenizer=datamodule.tokenizer,
+    )
+
     log.info(f"Instantiating model <{config.model._target_}>")
     model: LightningModule = hydra.utils.instantiate(
         config.model,
         network=network,
         tokenizer=datamodule.tokenizer,
+        decoder=decoder,
         loss_fn=loss_fn,
         optimizer_config=config.optimizer,
         lr_scheduler_config=config.lr_scheduler,
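
Note: the new `decoder: greedy` entry in config.yaml selects a Hydra config group file, training/conf/decoder/greedy.yaml, which is not part of this commit. Judging from how run.py instantiates it (network and tokenizer are injected at call time), the file is presumably little more than a `_target_` pointing at the new class. The sketch below is an assumption about its contents, not part of the diff:

# training/conf/decoder/greedy.yaml (assumed contents, not in this commit)
_target_: text_recognizer.models.greedy_decoder.GreedyDecoder
max_output_len: 682

With a file like this in place, `hydra.utils.instantiate(config.decoder, network=network, tokenizer=datamodule.tokenizer)` returns a ready-to-call GreedyDecoder, and swapping in a different decoding strategy later (e.g. beam search) only requires another entry in the `decoder` config group.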