diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 18:21:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 18:21:13 +0200 |
commit | 03dae09c63f2079f37bbf25fd9ded6f20f1490da (patch) | |
tree | 0ce80f3a0c8fafaea2e3d2734e4ddc577c504b2d /text_recognizer/models/transformer.py | |
parent | 8207dc902db439d606fe36d726f0405aabbf173e (diff) |
Add transformer lit model
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 91 |
1 files changed, 91 insertions, 0 deletions
# Reconstructed from: diff --git a/text_recognizer/models/transformer.py
# b/text_recognizer/models/transformer.py (new file mode 100644, index 0000000..9b02e78)
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Tuple

import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import wandb

from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import LitBaseModel


class LitTransformerModel(LitBaseModel):
    """A PyTorch Lightning model for transformer networks.

    Wraps the network built by ``LitBaseModel`` with training, validation and
    test steps, and tracks the character error rate (CER) on the validation
    and test splits.
    """

    def __init__(
        self,
        network_args: Dict,
        optimizer_args: Dict,
        lr_scheduler_args: Dict,
        criterion_args: Dict,
        monitor: str = "val_loss",
        mapping: Optional[List[str]] = None,
    ) -> None:
        """Initialize the model, the token mapping and the CER metrics.

        Args:
            network_args: Forwarded unchanged to ``LitBaseModel``.
            optimizer_args: Forwarded unchanged to ``LitBaseModel``.
            lr_scheduler_args: Forwarded unchanged to ``LitBaseModel``.
            criterion_args: Forwarded unchanged to ``LitBaseModel``.
            monitor: Forwarded unchanged to ``LitBaseModel``.
            mapping: Passed to ``configure_mapping``, which currently ignores
                it and always uses the EMNIST mapping.
        """
        super().__init__(
            network_args,
            optimizer_args,
            lr_scheduler_args,
            criterion_args,
            monitor,
        )

        self.mapping, ignore_tokens = self.configure_mapping(mapping)
        # configure_mapping returns [start, end, pad]; keep the pad index so
        # that prediction logging does not rely on a hard-coded value.
        self.pad_index = ignore_tokens[-1]
        self.val_cer = CharacterErrorRate(ignore_tokens)
        self.test_cer = CharacterErrorRate(ignore_tokens)

    def forward(self, data: Tensor) -> Tensor:
        """Forward pass with the transformer network.

        Delegates to ``self.network.predict`` rather than calling the network
        directly, so no target sequence is required.
        """
        return self.network.predict(data)

    @staticmethod
    def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
        """Configure mapping.

        Args:
            mapping: Currently unused; the EMNIST mapping is always loaded.

        Returns:
            The EMNIST mapping and the indices of the ``<s>``, ``<e>`` and
            ``<p>`` tokens, in that order.
        """
        mapping, inverse_mapping, _ = emnist_mapping()
        start_index = inverse_mapping["<s>"]
        end_index = inverse_mapping["<e>"]
        pad_index = inverse_mapping["<p>"]
        ignore_tokens = [start_index, end_index, pad_index]
        # TODO: add case for sentence pieces
        return mapping, ignore_tokens

    def _log_prediction(self, data: Tensor, pred: Tensor) -> None:
        """Logs prediction on image with wandb.

        Only the first item of ``data`` and ``pred`` is logged; pad tokens are
        dropped from the caption.
        """
        pad_index = self.pad_index
        pred_str = "".join(
            self.mapping[i] for i in pred[0].tolist() if i != pad_index
        )
        try:
            self.logger.experiment.log(
                {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]}
            )
        except AttributeError:
            # Best effort: the attached logger (if any) may not expose a
            # wandb-style ``experiment.log``; skip image logging in that case.
            pass

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step.

        The network receives the targets without their last position along
        the second axis, and the loss is computed against the targets without
        their first position (teacher forcing).
        NOTE(review): presumably targets are (batch, sequence) -- confirm.
        """
        data, targets = batch
        logits = self.network(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step.

        Logs the validation loss, logs an example prediction and updates the
        validation CER metric.
        """
        data, targets = batch

        # Shift along the second axis, exactly as in training_step. Slicing
        # the first axis (targets[:-1]) would drop whole rows and misalign
        # the network input with the loss targets.
        logits = self.network(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("val_loss", loss, prog_bar=True)

        pred = self.network.predict(data)
        self._log_prediction(data, pred)
        self.val_cer(pred, targets)
        self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step.

        Logs an example prediction and updates the test CER metric.
        """
        data, targets = batch
        pred = self.network.predict(data)
        self._log_prediction(data, pred)
        self.test_cer(pred, targets)
        self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)