Diffstat (limited to 'text_recognizer/models/conformer.py')
-rw-r--r--  text_recognizer/models/conformer.py  125
1 file changed, 125 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py
new file mode 100644
index 0000000..ee3d1e3
--- /dev/null
+++ b/text_recognizer/models/conformer.py
@@ -0,0 +1,125 @@
+"""Lightning Conformer model."""
+import itertools
+from typing import Optional, Tuple, Type
+
+from omegaconf import DictConfig
+import torch
+from torch import nn, Tensor
+
+from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.models.base import LitBase
+from text_recognizer.models.metrics import CharacterErrorRate
+from text_recognizer.models.util import first_element
+
+
+class LitConformer(LitBase):
+ """A PyTorch Lightning model for transformer networks."""
+
+    def __init__(
+        self,
+        network: Type[nn.Module],
+        loss_fn: Type[nn.Module],
+        optimizer_configs: DictConfig,
+        lr_scheduler_configs: Optional[DictConfig],
+        mapping: Type[AbstractMapping],
+        max_output_len: int = 451,
+        start_token: str = "<s>",
+        end_token: str = "<e>",
+        pad_token: str = "<p>",
+        blank_token: str = "<b>",
+    ) -> None:
+        super().__init__(
+            network, loss_fn, optimizer_configs, lr_scheduler_configs, mapping
+        )
+        self.max_output_len = max_output_len
+        self.start_token = start_token
+        self.end_token = end_token
+        self.pad_token = pad_token
+        self.blank_token = blank_token
+        self.start_index = int(self.mapping.get_index(self.start_token))
+        self.end_index = int(self.mapping.get_index(self.end_token))
+        self.pad_index = int(self.mapping.get_index(self.pad_token))
+        self.blank_index = int(self.mapping.get_index(self.blank_token))
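+        # Special token indices are excluded from the CER metric and decoding.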
+        self.ignore_indices = {
+            self.start_index,
+            self.end_index,
+            self.pad_index,
+            self.blank_index,
+        }
+        self.val_cer = CharacterErrorRate(self.ignore_indices)
+        self.test_cer = CharacterErrorRate(self.ignore_indices)
+
+    @torch.no_grad()
+    def predict(self, x: Tensor) -> str:
+        """Predicts a sequence of characters."""
+        logits = self(x)
+        logprobs = torch.log_softmax(logits, dim=1)
+        pred = self.decode(logprobs, self.max_output_len)[0].tolist()
+        return "".join([self.mapping[i] for i in pred if i not in self.ignore_indices])
+
+    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+        """Training step."""
+        data, targets = batch
+        logits = self(data)
+        logprobs = torch.log_softmax(logits, dim=1)
+        B, _, S = logprobs.shape
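+        # The network output length is fixed, so every CTC input length is S.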
+        input_length = torch.ones(B).type_as(logprobs).int() * S
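+        # The unpadded target length is the index of the first pad token.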
+        target_length = first_element(targets, self.pad_index).type_as(targets)
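+        # CTC-style losses take log probs shaped (S, B, C), hence the permute.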
+        loss = self.loss_fn(
+            logprobs.permute(2, 0, 1), targets, input_length, target_length
+        )
+        self.log("train/loss", loss)
+        return loss
+
+    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step."""
+        data, targets = batch
+        logits = self(data)
+        logprobs = torch.log_softmax(logits, dim=1)
+        B, _, S = logprobs.shape
+        input_length = torch.ones(B).type_as(logprobs).int() * S
+        target_length = first_element(targets, self.pad_index).type_as(targets)
+        loss = self.loss_fn(
+            logprobs.permute(2, 0, 1), targets, input_length, target_length
+        )
+        self.log("val/loss", loss)
+        preds = self.decode(logprobs, targets.shape[1])
+        self.val_acc(preds, targets)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
+        self.val_cer(preds, targets)
+        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+
+    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step."""
+        data, targets = batch
+        logits = self(data)
+        logprobs = torch.log_softmax(logits, dim=1)
+        preds = self.decode(logprobs, targets.shape[1])
+        self.test_acc(preds, targets)
+        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        self.test_cer(preds, targets)
+        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+
+    def decode(self, logprobs: Tensor, max_length: int) -> Tensor:
+        """Greedily decodes a log prob sequence.
+
+        Args:
+            logprobs (Tensor): Log probabilities.
+            max_length (int): Max length of a sequence.
+
+        Shapes:
+            - logprobs: :math:`(B, C, S)`
+            - output: :math:`(B, S)`
+
+        Returns:
+            Tensor: Predicted sequences of character indices.
+        """
+        B = logprobs.shape[0]
+        argmax = logprobs.argmax(1)
+        decoded = torch.ones((B, max_length)).type_as(logprobs).int() * self.pad_index
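+        # CTC collapse: merge runs of repeated symbols, then drop blanks.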
+        for i in range(B):
+            seq = [
+                b
+                for b, _ in itertools.groupby(argmax[i].tolist())
+                if b != self.blank_index
+            ][:max_length]
+            for j, c in enumerate(seq):
+                decoded[i, j] = c
+        return decoded
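+
+
+# Minimal usage sketch (illustrative only): `ConformerNetwork`, the optimizer
+# config, and the mapping instance are assumed to be provided elsewhere in the
+# repo; the loss is CTC-style, matching the (S, B, C) permute above.
+#
+#     network = ConformerNetwork(...)
+#     model = LitConformer(
+#         network=network,
+#         loss_fn=nn.CTCLoss(blank=int(mapping.get_index("<b>")), zero_infinity=True),
+#         optimizer_configs=optimizer_configs,
+#         lr_scheduler_configs=None,
+#         mapping=mapping,
+#     )
+#     text = model.predict(image)  # e.g. a (1, 1, H, W) image tensor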