summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 18:21:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 18:21:13 +0200
commit03dae09c63f2079f37bbf25fd9ded6f20f1490da (patch)
tree0ce80f3a0c8fafaea2e3d2734e4ddc577c504b2d /text_recognizer/models
parent8207dc902db439d606fe36d726f0405aabbf173e (diff)
Add transformer lit model
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py2
-rw-r--r--text_recognizer/models/transformer.py91
2 files changed, 92 insertions, 1 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 0c70625..46e5136 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,7 +11,7 @@ import torchmetrics
from text_recognizer import networks
-class BaseModel(pl.LightningModule):
+class LitBaseModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
def __init__(
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
new file mode 100644
index 0000000..9b02e78
--- /dev/null
+++ b/text_recognizer/models/transformer.py
@@ -0,0 +1,91 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Dict, List, Optional, Tuple
+
+import pytorch_lightning as pl
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+import wandb
+
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.models.metrics import CharacterErrorRate
+from text_recognizer.models.base import LitBaseModel
+
+
+class LitTransformerModel(LitBaseModel):
+ """A PyTorch Lightning model for transformer networks."""
+
+ def __init__(
+ self,
+ network_args: Dict,
+ optimizer_args: Dict,
+ lr_scheduler_args: Dict,
+ criterion_args: Dict,
+ monitor: str = "val_loss",
+ mapping: Optional[List[str]] = None,
+ ) -> None:
+ super().__init__(
+ network_args,
+ optimizer_args,
+ lr_scheduler_args,
+ criterion_args,
+ monitor)
+
+ self.mapping, ignore_tokens = self.configure_mapping(mapping)
+ self.val_cer = CharacterErrorRate(ignore_tokens)
+ self.test_cer = CharacterErrorRate(ignore_tokens)
+
+ def forward(self, data: Tensor) -> Tensor:
+ """Forward pass with the transformer network."""
+ return self.network.predict(data)
+
+ @staticmethod
+ def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
+ """Configure mapping."""
+ mapping, inverse_mapping, _ = emnist_mapping()
+ start_index = inverse_mapping["<s>"]
+ end_index = inverse_mapping["<e>"]
+ pad_index = inverse_mapping["<p>"]
+ ignore_tokens = [start_index, end_index, pad_index]
+ # TODO: add case for sentence pieces
+ return mapping, ignore_tokens
+
+ def _log_prediction(self, data: Tensor, pred: Tensor) -> None:
+ """Logs prediction on image with wandb."""
+ pred_str = "".join(self.mapping[i] for i in pred[0].tolist() if i != 3) # pad index is 3
+ try:
+ self.logger.experiment.log({
+ "val_pred_examples": [wandb.Image(data[0], caption=pred_str)]
+ })
+ except AttributeError:
+ pass
+
+ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+ """Training step."""
+ data, targets = batch
+ logits = self.network(data, targets[:, :-1])
+ loss = self.loss_fn(logits, targets[:, 1:])
+ self.log("train_loss", loss)
+ return loss
+
+ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Validation step."""
+ data, targets = batch
+
+ logits = self.network(data, targets[:-1])
+ loss = self.loss_fn(logits, targets[1:])
+ self.log("val_loss", loss, prog_bar=True)
+
+ pred = self.network.predict(data)
+ self._log_prediction(data, pred)
+ self.val_cer(pred, targets)
+ self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+
+ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Test step."""
+ data, targets = batch
+ pred = self.network.predict(data)
+ self._log_prediction(data, pred)
+ self.test_cer(pred, targets)
+ self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)