diff options
Diffstat (limited to 'text_recognizer/model')
-rw-r--r-- | text_recognizer/model/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/model/base.py | 96 | ||||
-rw-r--r-- | text_recognizer/model/greedy_decoder.py | 58 | ||||
-rw-r--r-- | text_recognizer/model/transformer.py | 89 |
4 files changed, 244 insertions, 0 deletions
diff --git a/text_recognizer/model/__init__.py b/text_recognizer/model/__init__.py new file mode 100644 index 0000000..1982daf --- /dev/null +++ b/text_recognizer/model/__init__.py @@ -0,0 +1 @@ +"""PyTorch Lightning models modules.""" diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py new file mode 100644 index 0000000..1cff796 --- /dev/null +++ b/text_recognizer/model/base.py @@ -0,0 +1,96 @@ +"""Base PyTorch Lightning model.""" +from typing import Any, Dict, Optional, Tuple, Type + +import hydra +import torch +from loguru import logger as log +from omegaconf import DictConfig +import pytorch_lightning as L +from torch import nn, Tensor +from torchmetrics import Accuracy + +from text_recognizer.data.tokenizer import Tokenizer + + +class LitBase(L.LightningModule): + """Abstract PyTorch Lightning class.""" + + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_config: DictConfig, + lr_scheduler_config: Optional[DictConfig], + tokenizer: Tokenizer, + ) -> None: + super().__init__() + + self.network = network + self.loss_fn = loss_fn + self.optimizer_config = optimizer_config + self.lr_scheduler_config = lr_scheduler_config + self.tokenizer = tokenizer + ignore_index = int(self.tokenizer.get_value("<p>")) + # Placeholders + self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + + def optimizer_zero_grad( + self, + epoch: int, + batch_idx: int, + optimizer: Type[torch.optim.Optimizer], + ) -> None: + """Optimal way to set grads to zero.""" + optimizer.zero_grad(set_to_none=True) + + def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: + """Configures the optimizer.""" + log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") + return hydra.utils.instantiate( + self.optimizer_config, params=self.network.parameters() + ) + + def _configure_lr_schedulers( + self, optimizer: Type[torch.optim.Optimizer] + ) -> Optional[Dict[str, Any]]: + """Configures the lr scheduler.""" + log.info( + f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" + ) + monitor = self.lr_scheduler_config.pop("monitor") + interval = self.lr_scheduler_config.pop("interval") + return { + "monitor": monitor, + "interval": interval, + "scheduler": hydra.utils.instantiate( + self.lr_scheduler_config, optimizer=optimizer + ), + } + + def configure_optimizers( + self, + ) -> Dict[str, Any]: + """Configures optimizer and lr scheduler.""" + optimizer = self._configure_optimizer() + if self.lr_scheduler_config is not None: + scheduler = self._configure_lr_schedulers(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def forward(self, data: Tensor) -> Tensor: + """Feedforward pass.""" + return self.network(data) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + pass + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + pass + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + pass diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py new file mode 100644 index 0000000..5cbbb66 --- /dev/null +++ b/text_recognizer/model/greedy_decoder.py @@ -0,0 +1,58 @@ +"""Greedy decoder.""" +from typing import Type +from text_recognizer.data.tokenizer import Tokenizer +import torch +from torch import nn, Tensor + + +class GreedyDecoder: + def __init__( + self, + network: Type[nn.Module], + tokenizer: Tokenizer, + max_output_len: int = 682, + ) -> None: + self.network = network + self.start_index = tokenizer.start_index + self.end_index = tokenizer.end_index + self.pad_index = tokenizer.pad_index + self.max_output_len = max_output_len + + def __call__(self, x: Tensor) -> Tensor: + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + img_features = self.network.encode(x) + + # Create a placeholder matrix for storing outputs from the network + indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + indecies[:, 0] = self.start_index + + try: + for i in range(1, self.max_output_len): + tokens = indecies[:, :i] # (B, Sy) + logits = self.network.decode(tokens, img_features) # (B, C, Sy) + indecies_ = torch.argmax(logits, dim=1) # (B, Sy) + indecies[:, i : i + 1] = indecies_[:, -1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + (indecies[:, i - 1] == self.end_index) + | (indecies[:, i - 1] == self.pad_index) + ).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = (indecies[:, i - 1] == self.end_index) | ( + indecies[:, i - 1] == self.pad_index + ) + indecies[idx, i] = self.pad_index + + return indecies + except Exception: + # TODO: investigate this error more + print(x.shape) + # print(indecies) + print(indecies.shape) + print(img_features.shape) diff --git a/text_recognizer/model/transformer.py b/text_recognizer/model/transformer.py new file mode 100644 index 0000000..23b2a3a --- /dev/null +++ b/text_recognizer/model/transformer.py @@ -0,0 +1,89 @@ +"""Lightning model for transformer networks.""" +from typing import Callable, Optional, Sequence, Tuple, Type +from text_recognizer.model.greedy_decoder import GreedyDecoder + +import torch +from omegaconf import DictConfig +from torch import nn, Tensor +from torchmetrics import CharErrorRate, WordErrorRate + +from text_recognizer.data.tokenizer import Tokenizer +from text_recognizer.model.base import LitBase + + +class LitTransformer(LitBase): + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_config: DictConfig, + tokenizer: Tokenizer, + decoder: Callable = GreedyDecoder, + lr_scheduler_config: Optional[DictConfig] = None, + max_output_len: int = 682, + ) -> None: + super().__init__( + network, + loss_fn, + optimizer_config, + lr_scheduler_config, + tokenizer, + ) + self.max_output_len = max_output_len + self.val_cer = CharErrorRate() + self.test_cer = CharErrorRate() + self.val_wer = WordErrorRate() + self.test_wer = WordErrorRate() + self.decoder = decoder + + def forward(self, data: Tensor) -> Tensor: + """Autoregressive forward pass.""" + return self.predict(data) + + def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor: + """Non-autoregressive forward pass.""" + return self.network(data, targets) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + data, targets = batch + logits = self.teacher_forward(data, targets[:, :-1]) + loss = self.loss_fn(logits, targets[:, 1:]) + self.log("train/loss", loss, prog_bar=True) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, targets = batch + preds = self(data) + pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) + + self.val_acc(preds, targets) + self.val_cer(pred_text, target_text) + self.val_wer(pred_text, target_text) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) + self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) + self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, targets = batch + preds = self(data) + pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) + + self.test_acc(preds, targets) + self.test_cer(pred_text, target_text) + self.test_wer(pred_text, target_text) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) + + def _to_tokens( + self, + indices: Tensor, + ) -> Sequence[str]: + return [self.tokenizer.decode(i) for i in indices] + + @torch.no_grad() + def predict(self, x: Tensor) -> Tensor: + return self.decoder(x) |