From 0421daf6bd97596703f426ba61c401599b538eeb Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 25 Aug 2023 23:18:31 +0200 Subject: Rename --- text_recognizer/models/__init__.py | 2 - text_recognizer/models/base.py | 97 ------------------------------- text_recognizer/models/greedy_decoder.py | 51 ---------------- text_recognizer/models/transformer.py | 99 -------------------------------- 4 files changed, 249 deletions(-) delete mode 100644 text_recognizer/models/__init__.py delete mode 100644 text_recognizer/models/base.py delete mode 100644 text_recognizer/models/greedy_decoder.py delete mode 100644 text_recognizer/models/transformer.py (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py deleted file mode 100644 index cc02487..0000000 --- a/text_recognizer/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""PyTorch Lightning models modules.""" -from text_recognizer.models.transformer import LitTransformer diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py deleted file mode 100644 index 4dd5cdf..0000000 --- a/text_recognizer/models/base.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Base PyTorch Lightning model.""" -from typing import Any, Dict, Optional, Tuple, Type - -import hydra -import torch -from loguru import logger as log -from omegaconf import DictConfig -from pytorch_lightning import LightningModule -from torch import nn, Tensor -from torchmetrics import Accuracy - -from text_recognizer.data.tokenizer import Tokenizer - - -class LitBase(LightningModule): - """Abstract PyTorch Lightning class.""" - - def __init__( - self, - network: Type[nn.Module], - loss_fn: Type[nn.Module], - optimizer_config: DictConfig, - lr_scheduler_config: Optional[DictConfig], - tokenizer: Tokenizer, - ) -> None: - super().__init__() - - self.network = network - self.loss_fn = loss_fn - self.optimizer_config = optimizer_config - self.lr_scheduler_config = lr_scheduler_config - self.tokenizer = tokenizer - ignore_index = int(self.tokenizer.get_value("

")) - # Placeholders - self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) - self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) - self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) - - def optimizer_zero_grad( - self, - epoch: int, - batch_idx: int, - optimizer: Type[torch.optim.Optimizer], - optimizer_idx: int, - ) -> None: - """Optimal way to set grads to zero.""" - optimizer.zero_grad(set_to_none=True) - - def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: - """Configures the optimizer.""" - log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") - return hydra.utils.instantiate( - self.optimizer_config, params=self.network.parameters() - ) - - def _configure_lr_schedulers( - self, optimizer: Type[torch.optim.Optimizer] - ) -> Optional[Dict[str, Any]]: - """Configures the lr scheduler.""" - log.info( - f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" - ) - monitor = self.lr_scheduler_config.pop("monitor") - interval = self.lr_scheduler_config.pop("interval") - return { - "monitor": monitor, - "interval": interval, - "scheduler": hydra.utils.instantiate( - self.lr_scheduler_config, optimizer=optimizer - ), - } - - def configure_optimizers( - self, - ) -> Dict[str, Any]: - """Configures optimizer and lr scheduler.""" - optimizer = self._configure_optimizer() - if self.lr_scheduler_config is not None: - scheduler = self._configure_lr_schedulers(optimizer) - return {"optimizer": optimizer, "lr_scheduler": scheduler} - return {"optimizer": optimizer} - - def forward(self, data: Tensor) -> Tensor: - """Feedforward pass.""" - return self.network(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - pass - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - pass - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - pass diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py deleted file mode 100644 index 9d2f192..0000000 --- a/text_recognizer/models/greedy_decoder.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Greedy decoder.""" -from typing import Type -from text_recognizer.data.tokenizer import Tokenizer -import torch -from torch import nn, Tensor - - -class GreedyDecoder: - def __init__( - self, - network: Type[nn.Module], - tokenizer: Tokenizer, - max_output_len: int = 682, - ) -> None: - self.network = network - self.start_index = tokenizer.start_index - self.end_index = tokenizer.end_index - self.pad_index = tokenizer.pad_index - self.max_output_len = max_output_len - - def __call__(self, x: Tensor) -> Tensor: - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - img_features = self.network.encode(x) - - # Create a placeholder matrix for storing outputs from the network - indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - indecies[:, 0] = self.start_index - - for Sy in range(1, self.max_output_len): - tokens = indecies[:, :Sy] # (B, Sy) - logits = self.network.decode(tokens, img_features) # (B, C, Sy) - indecies_ = torch.argmax(logits, dim=1) # (B, Sy) - indecies[:, Sy : Sy + 1] = indecies_[:, -1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - (indecies[:, Sy - 1] == self.end_index) - | (indecies[:, Sy - 1] == self.pad_index) - ).all(): - break - - # Set all tokens after end token to pad token. - for Sy in range(1, self.max_output_len): - idx = (indecies[:, Sy - 1] == self.end_index) | ( - indecies[:, Sy - 1] == self.pad_index - ) - indecies[idx, Sy] = self.pad_index - - return indecies diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py deleted file mode 100644 index bbfaac0..0000000 --- a/text_recognizer/models/transformer.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Lightning model for base Transformers.""" -from typing import Callable, Optional, Sequence, Tuple, Type -from text_recognizer.models.greedy_decoder import GreedyDecoder - -import torch -from omegaconf import DictConfig -from torch import nn, Tensor -from torchmetrics import CharErrorRate, WordErrorRate - -from text_recognizer.data.tokenizer import Tokenizer -from text_recognizer.models.base import LitBase - - -class LitTransformer(LitBase): - """A PyTorch Lightning model for transformer networks.""" - - def __init__( - self, - network: Type[nn.Module], - loss_fn: Type[nn.Module], - optimizer_config: DictConfig, - tokenizer: Tokenizer, - decoder: Callable = GreedyDecoder, - lr_scheduler_config: Optional[DictConfig] = None, - max_output_len: int = 682, - ) -> None: - super().__init__( - network, - loss_fn, - optimizer_config, - lr_scheduler_config, - tokenizer, - ) - self.max_output_len = max_output_len - self.val_cer = CharErrorRate() - self.test_cer = CharErrorRate() - self.val_wer = WordErrorRate() - self.test_wer = WordErrorRate() - self.decoder = decoder - - def forward(self, data: Tensor) -> Tensor: - """Autoregressive forward pass.""" - return self.predict(data) - - def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor: - """Non-autoregressive forward pass.""" - return self.network(data, targets) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, targets = batch - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, targets = batch - - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) - preds = self.predict(data) - pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) - - self.val_acc(preds, targets) - self.val_cer(pred_text, target_text) - self.val_wer(pred_text, target_text) - self.log("val/loss", loss, on_step=False, on_epoch=True) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) - self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) - self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, targets = batch - - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) - preds = self(data) - pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) - - self.test_acc(preds, targets) - self.test_cer(pred_text, target_text) - self.test_wer(pred_text, target_text) - self.log("test/loss", loss, on_step=False, on_epoch=True) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) - self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) - - def _to_tokens( - self, - indecies: Tensor, - ) -> Sequence[str]: - return [self.tokenizer.decode(i) for i in indecies] - - @torch.no_grad() - def predict(self, x: Tensor) -> Tensor: - return self.decoder(x) -- cgit v1.2.3-70-g09d2