From 0421daf6bd97596703f426ba61c401599b538eeb Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 25 Aug 2023 23:18:31 +0200 Subject: Rename --- text_recognizer/model/__init__.py | 1 + text_recognizer/model/base.py | 96 +++++++++++++++++++++++++++++++ text_recognizer/model/greedy_decoder.py | 58 +++++++++++++++++++ text_recognizer/model/transformer.py | 89 ++++++++++++++++++++++++++++ text_recognizer/models/__init__.py | 2 - text_recognizer/models/base.py | 97 ------------------------------- text_recognizer/models/greedy_decoder.py | 51 ---------------- text_recognizer/models/transformer.py | 99 -------------------------------- 8 files changed, 244 insertions(+), 249 deletions(-) create mode 100644 text_recognizer/model/__init__.py create mode 100644 text_recognizer/model/base.py create mode 100644 text_recognizer/model/greedy_decoder.py create mode 100644 text_recognizer/model/transformer.py delete mode 100644 text_recognizer/models/__init__.py delete mode 100644 text_recognizer/models/base.py delete mode 100644 text_recognizer/models/greedy_decoder.py delete mode 100644 text_recognizer/models/transformer.py diff --git a/text_recognizer/model/__init__.py b/text_recognizer/model/__init__.py new file mode 100644 index 0000000..1982daf --- /dev/null +++ b/text_recognizer/model/__init__.py @@ -0,0 +1 @@ +"""PyTorch Lightning models modules.""" diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py new file mode 100644 index 0000000..1cff796 --- /dev/null +++ b/text_recognizer/model/base.py @@ -0,0 +1,96 @@ +"""Base PyTorch Lightning model.""" +from typing import Any, Dict, Optional, Tuple, Type + +import hydra +import torch +from loguru import logger as log +from omegaconf import DictConfig +import pytorch_lightning as L +from torch import nn, Tensor +from torchmetrics import Accuracy + +from text_recognizer.data.tokenizer import Tokenizer + + +class LitBase(L.LightningModule): + """Abstract PyTorch Lightning class.""" + + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_config: DictConfig, + lr_scheduler_config: Optional[DictConfig], + tokenizer: Tokenizer, + ) -> None: + super().__init__() + + self.network = network + self.loss_fn = loss_fn + self.optimizer_config = optimizer_config + self.lr_scheduler_config = lr_scheduler_config + self.tokenizer = tokenizer + ignore_index = int(self.tokenizer.get_value("

")) + # Placeholders + self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + + def optimizer_zero_grad( + self, + epoch: int, + batch_idx: int, + optimizer: Type[torch.optim.Optimizer], + ) -> None: + """Optimal way to set grads to zero.""" + optimizer.zero_grad(set_to_none=True) + + def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: + """Configures the optimizer.""" + log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") + return hydra.utils.instantiate( + self.optimizer_config, params=self.network.parameters() + ) + + def _configure_lr_schedulers( + self, optimizer: Type[torch.optim.Optimizer] + ) -> Optional[Dict[str, Any]]: + """Configures the lr scheduler.""" + log.info( + f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" + ) + monitor = self.lr_scheduler_config.pop("monitor") + interval = self.lr_scheduler_config.pop("interval") + return { + "monitor": monitor, + "interval": interval, + "scheduler": hydra.utils.instantiate( + self.lr_scheduler_config, optimizer=optimizer + ), + } + + def configure_optimizers( + self, + ) -> Dict[str, Any]: + """Configures optimizer and lr scheduler.""" + optimizer = self._configure_optimizer() + if self.lr_scheduler_config is not None: + scheduler = self._configure_lr_schedulers(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def forward(self, data: Tensor) -> Tensor: + """Feedforward pass.""" + return self.network(data) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + pass + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + pass + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + pass diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py new file mode 100644 index 0000000..5cbbb66 --- /dev/null +++ b/text_recognizer/model/greedy_decoder.py @@ -0,0 +1,58 @@ +"""Greedy decoder.""" +from typing import Type +from text_recognizer.data.tokenizer import Tokenizer +import torch +from torch import nn, Tensor + + +class GreedyDecoder: + def __init__( + self, + network: Type[nn.Module], + tokenizer: Tokenizer, + max_output_len: int = 682, + ) -> None: + self.network = network + self.start_index = tokenizer.start_index + self.end_index = tokenizer.end_index + self.pad_index = tokenizer.pad_index + self.max_output_len = max_output_len + + def __call__(self, x: Tensor) -> Tensor: + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + img_features = self.network.encode(x) + + # Create a placeholder matrix for storing outputs from the network + indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + indecies[:, 0] = self.start_index + + try: + for i in range(1, self.max_output_len): + tokens = indecies[:, :i] # (B, Sy) + logits = self.network.decode(tokens, img_features) # (B, C, Sy) + indecies_ = torch.argmax(logits, dim=1) # (B, Sy) + indecies[:, i : i + 1] = indecies_[:, -1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + (indecies[:, i - 1] == self.end_index) + | (indecies[:, i - 1] == self.pad_index) + ).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = (indecies[:, i - 1] == self.end_index) | ( + indecies[:, i - 1] == self.pad_index + ) + indecies[idx, i] = self.pad_index + + return indecies + except Exception: + # TODO: investigate this error more + print(x.shape) + # print(indecies) + print(indecies.shape) + print(img_features.shape) diff --git a/text_recognizer/model/transformer.py b/text_recognizer/model/transformer.py new file mode 100644 index 0000000..23b2a3a --- /dev/null +++ b/text_recognizer/model/transformer.py @@ -0,0 +1,89 @@ +"""Lightning model for transformer networks.""" +from typing import Callable, Optional, Sequence, Tuple, Type +from text_recognizer.model.greedy_decoder import GreedyDecoder + +import torch +from omegaconf import DictConfig +from torch import nn, Tensor +from torchmetrics import CharErrorRate, WordErrorRate + +from text_recognizer.data.tokenizer import Tokenizer +from text_recognizer.model.base import LitBase + + +class LitTransformer(LitBase): + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_config: DictConfig, + tokenizer: Tokenizer, + decoder: Callable = GreedyDecoder, + lr_scheduler_config: Optional[DictConfig] = None, + max_output_len: int = 682, + ) -> None: + super().__init__( + network, + loss_fn, + optimizer_config, + lr_scheduler_config, + tokenizer, + ) + self.max_output_len = max_output_len + self.val_cer = CharErrorRate() + self.test_cer = CharErrorRate() + self.val_wer = WordErrorRate() + self.test_wer = WordErrorRate() + self.decoder = decoder + + def forward(self, data: Tensor) -> Tensor: + """Autoregressive forward pass.""" + return self.predict(data) + + def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor: + """Non-autoregressive forward pass.""" + return self.network(data, targets) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + data, targets = batch + logits = self.teacher_forward(data, targets[:, :-1]) + loss = self.loss_fn(logits, targets[:, 1:]) + self.log("train/loss", loss, prog_bar=True) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, targets = batch + preds = self(data) + pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) + + self.val_acc(preds, targets) + self.val_cer(pred_text, target_text) + self.val_wer(pred_text, target_text) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) + self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) + self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, targets = batch + preds = self(data) + pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) + + self.test_acc(preds, targets) + self.test_cer(pred_text, target_text) + self.test_wer(pred_text, target_text) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) + + def _to_tokens( + self, + indices: Tensor, + ) -> Sequence[str]: + return [self.tokenizer.decode(i) for i in indices] + + @torch.no_grad() + def predict(self, x: Tensor) -> Tensor: + return self.decoder(x) diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py deleted file mode 100644 index cc02487..0000000 --- a/text_recognizer/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""PyTorch Lightning models modules.""" -from text_recognizer.models.transformer import LitTransformer diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py deleted file mode 100644 index 4dd5cdf..0000000 --- a/text_recognizer/models/base.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Base PyTorch Lightning model.""" -from typing import Any, Dict, Optional, Tuple, Type - -import hydra -import torch -from loguru import logger as log -from omegaconf import DictConfig -from pytorch_lightning import LightningModule -from torch import nn, Tensor -from torchmetrics import Accuracy - -from text_recognizer.data.tokenizer import Tokenizer - - -class LitBase(LightningModule): - """Abstract PyTorch Lightning class.""" - - def __init__( - self, - network: Type[nn.Module], - loss_fn: Type[nn.Module], - optimizer_config: DictConfig, - lr_scheduler_config: Optional[DictConfig], - tokenizer: Tokenizer, - ) -> None: - super().__init__() - - self.network = network - self.loss_fn = loss_fn - self.optimizer_config = optimizer_config - self.lr_scheduler_config = lr_scheduler_config - self.tokenizer = tokenizer - ignore_index = int(self.tokenizer.get_value("

")) - # Placeholders - self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) - self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) - self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) - - def optimizer_zero_grad( - self, - epoch: int, - batch_idx: int, - optimizer: Type[torch.optim.Optimizer], - optimizer_idx: int, - ) -> None: - """Optimal way to set grads to zero.""" - optimizer.zero_grad(set_to_none=True) - - def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: - """Configures the optimizer.""" - log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") - return hydra.utils.instantiate( - self.optimizer_config, params=self.network.parameters() - ) - - def _configure_lr_schedulers( - self, optimizer: Type[torch.optim.Optimizer] - ) -> Optional[Dict[str, Any]]: - """Configures the lr scheduler.""" - log.info( - f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" - ) - monitor = self.lr_scheduler_config.pop("monitor") - interval = self.lr_scheduler_config.pop("interval") - return { - "monitor": monitor, - "interval": interval, - "scheduler": hydra.utils.instantiate( - self.lr_scheduler_config, optimizer=optimizer - ), - } - - def configure_optimizers( - self, - ) -> Dict[str, Any]: - """Configures optimizer and lr scheduler.""" - optimizer = self._configure_optimizer() - if self.lr_scheduler_config is not None: - scheduler = self._configure_lr_schedulers(optimizer) - return {"optimizer": optimizer, "lr_scheduler": scheduler} - return {"optimizer": optimizer} - - def forward(self, data: Tensor) -> Tensor: - """Feedforward pass.""" - return self.network(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - pass - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - pass - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - pass diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py deleted file mode 100644 index 9d2f192..0000000 --- a/text_recognizer/models/greedy_decoder.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Greedy decoder.""" -from typing import Type -from text_recognizer.data.tokenizer import Tokenizer -import torch -from torch import nn, Tensor - - -class GreedyDecoder: - def __init__( - self, - network: Type[nn.Module], - tokenizer: Tokenizer, - max_output_len: int = 682, - ) -> None: - self.network = network - self.start_index = tokenizer.start_index - self.end_index = tokenizer.end_index - self.pad_index = tokenizer.pad_index - self.max_output_len = max_output_len - - def __call__(self, x: Tensor) -> Tensor: - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - img_features = self.network.encode(x) - - # Create a placeholder matrix for storing outputs from the network - indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - indecies[:, 0] = self.start_index - - for Sy in range(1, self.max_output_len): - tokens = indecies[:, :Sy] # (B, Sy) - logits = self.network.decode(tokens, img_features) # (B, C, Sy) - indecies_ = torch.argmax(logits, dim=1) # (B, Sy) - indecies[:, Sy : Sy + 1] = indecies_[:, -1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - (indecies[:, Sy - 1] == self.end_index) - | (indecies[:, Sy - 1] == self.pad_index) - ).all(): - break - - # Set all tokens after end token to pad token. - for Sy in range(1, self.max_output_len): - idx = (indecies[:, Sy - 1] == self.end_index) | ( - indecies[:, Sy - 1] == self.pad_index - ) - indecies[idx, Sy] = self.pad_index - - return indecies diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py deleted file mode 100644 index bbfaac0..0000000 --- a/text_recognizer/models/transformer.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Lightning model for base Transformers.""" -from typing import Callable, Optional, Sequence, Tuple, Type -from text_recognizer.models.greedy_decoder import GreedyDecoder - -import torch -from omegaconf import DictConfig -from torch import nn, Tensor -from torchmetrics import CharErrorRate, WordErrorRate - -from text_recognizer.data.tokenizer import Tokenizer -from text_recognizer.models.base import LitBase - - -class LitTransformer(LitBase): - """A PyTorch Lightning model for transformer networks.""" - - def __init__( - self, - network: Type[nn.Module], - loss_fn: Type[nn.Module], - optimizer_config: DictConfig, - tokenizer: Tokenizer, - decoder: Callable = GreedyDecoder, - lr_scheduler_config: Optional[DictConfig] = None, - max_output_len: int = 682, - ) -> None: - super().__init__( - network, - loss_fn, - optimizer_config, - lr_scheduler_config, - tokenizer, - ) - self.max_output_len = max_output_len - self.val_cer = CharErrorRate() - self.test_cer = CharErrorRate() - self.val_wer = WordErrorRate() - self.test_wer = WordErrorRate() - self.decoder = decoder - - def forward(self, data: Tensor) -> Tensor: - """Autoregressive forward pass.""" - return self.predict(data) - - def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor: - """Non-autoregressive forward pass.""" - return self.network(data, targets) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, targets = batch - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, targets = batch - - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) - preds = self.predict(data) - pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) - - self.val_acc(preds, targets) - self.val_cer(pred_text, target_text) - self.val_wer(pred_text, target_text) - self.log("val/loss", loss, on_step=False, on_epoch=True) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) - self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) - self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, targets = batch - - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) - preds = self(data) - pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) - - self.test_acc(preds, targets) - self.test_cer(pred_text, target_text) - self.test_wer(pred_text, target_text) - self.log("test/loss", loss, on_step=False, on_epoch=True) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) - self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) - - def _to_tokens( - self, - indecies: Tensor, - ) -> Sequence[str]: - return [self.tokenizer.decode(i) for i in indecies] - - @torch.no_grad() - def predict(self, x: Tensor) -> Tensor: - return self.decoder(x) -- cgit v1.2.3-70-g09d2