Diffstat (limited to 'text_recognizer/model')
-rw-r--r--  text_recognizer/model/__init__.py         1
-rw-r--r--  text_recognizer/model/base.py            96
-rw-r--r--  text_recognizer/model/greedy_decoder.py  58
-rw-r--r--  text_recognizer/model/transformer.py     89
4 files changed, 244 insertions, 0 deletions
diff --git a/text_recognizer/model/__init__.py b/text_recognizer/model/__init__.py
new file mode 100644
index 0000000..1982daf
--- /dev/null
+++ b/text_recognizer/model/__init__.py
@@ -0,0 +1 @@
+"""PyTorch Lightning models modules."""
diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py
new file mode 100644
index 0000000..1cff796
--- /dev/null
+++ b/text_recognizer/model/base.py
@@ -0,0 +1,96 @@
+"""Base PyTorch Lightning model."""
+from typing import Any, Dict, Optional, Tuple
+
+import hydra
+import torch
+from loguru import logger as log
+from omegaconf import DictConfig
+import pytorch_lightning as L
+from torch import nn, Tensor
+from torchmetrics import Accuracy
+
+from text_recognizer.data.tokenizer import Tokenizer
+
+
+class LitBase(L.LightningModule):
+ """Abstract PyTorch Lightning class."""
+
+ def __init__(
+ self,
+ network: Type[nn.Module],
+ loss_fn: Type[nn.Module],
+ optimizer_config: DictConfig,
+ lr_scheduler_config: Optional[DictConfig],
+ tokenizer: Tokenizer,
+ ) -> None:
+ super().__init__()
+
+ self.network = network
+ self.loss_fn = loss_fn
+ self.optimizer_config = optimizer_config
+ self.lr_scheduler_config = lr_scheduler_config
+ self.tokenizer = tokenizer
+ ignore_index = int(self.tokenizer.get_value("<p>"))
+ # Placeholders
+ self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+ self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+ self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+
+    def optimizer_zero_grad(
+        self,
+        epoch: int,
+        batch_idx: int,
+        optimizer: torch.optim.Optimizer,
+    ) -> None:
+        """Sets gradients to None instead of zeroing them, which is faster."""
+        optimizer.zero_grad(set_to_none=True)
+
+    def _configure_optimizer(self) -> torch.optim.Optimizer:
+        """Configures the optimizer."""
+        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
+        return hydra.utils.instantiate(
+            self.optimizer_config, params=self.network.parameters()
+        )
+
+    def _configure_lr_schedulers(
+        self, optimizer: torch.optim.Optimizer
+    ) -> Dict[str, Any]:
+        """Configures the lr scheduler."""
+        log.info(
+            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
+        )
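+        # "monitor" and "interval" are Lightning bookkeeping keys, not scheduler
+        # arguments, so pop them from the config before instantiating the scheduler.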
+        monitor = self.lr_scheduler_config.pop("monitor")
+        interval = self.lr_scheduler_config.pop("interval")
+        return {
+            "monitor": monitor,
+            "interval": interval,
+            "scheduler": hydra.utils.instantiate(
+                self.lr_scheduler_config, optimizer=optimizer
+            ),
+        }
+
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Configures optimizer and lr scheduler."""
+        optimizer = self._configure_optimizer()
+        if self.lr_scheduler_config is not None:
+            scheduler = self._configure_lr_schedulers(optimizer)
+            return {"optimizer": optimizer, "lr_scheduler": scheduler}
+        return {"optimizer": optimizer}
+
+    def forward(self, data: Tensor) -> Tensor:
+        """Feedforward pass."""
+        return self.network(data)
+
+    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+        """Training step."""
+        raise NotImplementedError
+
+    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step."""
+        raise NotImplementedError
+
+    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step."""
+        raise NotImplementedError
diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py
new file mode 100644
index 0000000..5cbbb66
--- /dev/null
+++ b/text_recognizer/model/greedy_decoder.py
@@ -0,0 +1,58 @@
+"""Greedy decoder."""
+import torch
+from torch import nn, Tensor
+
+from text_recognizer.data.tokenizer import Tokenizer
+
+
+class GreedyDecoder:
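+    """Greedy autoregressive decoder for a transformer network."""
+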
+    def __init__(
+        self,
+        network: nn.Module,
+        tokenizer: Tokenizer,
+        max_output_len: int = 682,
+    ) -> None:
+        self.network = network
+        self.start_index = tokenizer.start_index
+        self.end_index = tokenizer.end_index
+        self.pad_index = tokenizer.pad_index
+        self.max_output_len = max_output_len
+
+    def __call__(self, x: Tensor) -> Tensor:
+        bsz = x.shape[0]
+
+        # Encode image(s) to latent vectors.
+        img_features = self.network.encode(x)
+
+        # Create a placeholder matrix for storing outputs from the network.
+        indices = torch.ones(
+            (bsz, self.max_output_len), dtype=torch.long, device=x.device
+        )
+        indices[:, 0] = self.start_index
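+        # Every position after the start token is overwritten by the loop below;
+        # anything after an end token is padded afterwards.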
+
+        for i in range(1, self.max_output_len):
+            tokens = indices[:, :i]  # (B, Sy)
+            logits = self.network.decode(tokens, img_features)  # (B, C, Sy)
+            preds = torch.argmax(logits, dim=1)  # (B, Sy)
+            indices[:, i : i + 1] = preds[:, -1:]
+
+            # Early stopping of the prediction loop once every sequence has
+            # emitted an end or padding token.
+            if (
+                (indices[:, i - 1] == self.end_index)
+                | (indices[:, i - 1] == self.pad_index)
+            ).all():
+                break
+
+        # Set all tokens after an end or padding token to the padding token.
+        for i in range(1, self.max_output_len):
+            idx = (indices[:, i - 1] == self.end_index) | (
+                indices[:, i - 1] == self.pad_index
+            )
+            indices[idx, i] = self.pad_index
+
+        return indices
diff --git a/text_recognizer/model/transformer.py b/text_recognizer/model/transformer.py
new file mode 100644
index 0000000..23b2a3a
--- /dev/null
+++ b/text_recognizer/model/transformer.py
@@ -0,0 +1,89 @@
+"""Lightning model for transformer networks."""
+from typing import Optional, Sequence, Tuple
+
+import torch
+from omegaconf import DictConfig
+from torch import nn, Tensor
+from torchmetrics import CharErrorRate, WordErrorRate
+
+from text_recognizer.data.tokenizer import Tokenizer
+from text_recognizer.model.base import LitBase
+from text_recognizer.model.greedy_decoder import GreedyDecoder
+
+
+class LitTransformer(LitBase):
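+    """Lightning model for transformer networks with greedy decoding."""
+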
+    def __init__(
+        self,
+        network: nn.Module,
+        loss_fn: nn.Module,
+        optimizer_config: DictConfig,
+        tokenizer: Tokenizer,
+        decoder: Optional[GreedyDecoder] = None,
+        lr_scheduler_config: Optional[DictConfig] = None,
+        max_output_len: int = 682,
+    ) -> None:
+        super().__init__(
+            network,
+            loss_fn,
+            optimizer_config,
+            lr_scheduler_config,
+            tokenizer,
+        )
+        self.max_output_len = max_output_len
+        self.val_cer = CharErrorRate()
+        self.test_cer = CharErrorRate()
+        self.val_wer = WordErrorRate()
+        self.test_wer = WordErrorRate()
+        self.decoder = decoder or GreedyDecoder(network, tokenizer, max_output_len)
+
+    def forward(self, data: Tensor) -> Tensor:
+        """Autoregressive forward pass."""
+        return self.predict(data)
+
+    def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
+        """Teacher-forced (non-autoregressive) forward pass."""
+        return self.network(data, targets)
+
+    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+        """Training step."""
+        data, targets = batch
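+        # Teacher forcing: the decoder input is the target shifted right, and the
+        # loss compares each position against the next target token.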
+        logits = self.teacher_forward(data, targets[:, :-1])
+        loss = self.loss_fn(logits, targets[:, 1:])
+        self.log("train/loss", loss, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step."""
+        data, targets = batch
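+        # Decode autoregressively (greedy, no teacher forcing) so the metrics
+        # reflect inference-time behaviour.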
+        preds = self(data)
+        pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
+
+        self.val_acc(preds, targets)
+        self.val_cer(pred_text, target_text)
+        self.val_wer(pred_text, target_text)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
+        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
+
+    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step."""
+        data, targets = batch
+        preds = self(data)
+        pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
+
+        self.test_acc(preds, targets)
+        self.test_cer(pred_text, target_text)
+        self.test_wer(pred_text, target_text)
+        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
+
+    def _to_tokens(self, indices: Tensor) -> Sequence[str]:
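+        """Decodes a batch of token indices into strings."""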
+        return [self.tokenizer.decode(i) for i in indices]
+
+    @torch.no_grad()
+    def predict(self, x: Tensor) -> Tensor:
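+        """Predicts token indices for the image(s) in x using the decoder."""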
+        return self.decoder(x)