diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 2 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 91 |
2 files changed, 92 insertions, 1 deletions
# Reconstructed from a whitespace-collapsed diff view of two files.
#
# Companion change in text_recognizer/models/base.py (one-line rename):
#     class BaseModel(pl.LightningModule)  ->  class LitBaseModel(pl.LightningModule)
#
# New file: text_recognizer/models/transformer.py
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Tuple

import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import wandb

from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import LitBaseModel


class LitTransformerModel(LitBaseModel):
    """A PyTorch Lightning model for transformer networks.

    Adds character-error-rate tracking for validation/test and logging of an
    example prediction on top of ``LitBaseModel``. ``self.network`` and
    ``self.loss_fn`` are used here but are not defined in this class; they are
    presumably set up by ``LitBaseModel`` from the ``*_args`` dictionaries.
    """

    def __init__(
        self,
        network_args: Dict,
        optimizer_args: Dict,
        lr_scheduler_args: Dict,
        criterion_args: Dict,
        monitor: str = "val_loss",
        mapping: Optional[List[str]] = None,
    ) -> None:
        """Initialise the base model, the token mapping and the CER metrics.

        Args:
            network_args: Forwarded unchanged to ``LitBaseModel``.
            optimizer_args: Forwarded unchanged to ``LitBaseModel``.
            lr_scheduler_args: Forwarded unchanged to ``LitBaseModel``.
            criterion_args: Forwarded unchanged to ``LitBaseModel``.
            monitor: Forwarded unchanged to ``LitBaseModel``.
            mapping: Passed to ``configure_mapping`` (which currently ignores
                it and always uses the EMNIST mapping).
        """
        super().__init__(
            network_args,
            optimizer_args,
            lr_scheduler_args,
            criterion_args,
            monitor,
        )

        self.mapping, ignore_tokens = self.configure_mapping(mapping)
        # configure_mapping orders the ignore tokens as [start, end, pad];
        # keep the pad index so the prediction logger does not hard-code it.
        self.pad_index = ignore_tokens[-1]
        self.val_cer = CharacterErrorRate(ignore_tokens)
        self.test_cer = CharacterErrorRate(ignore_tokens)

    def forward(self, data: Tensor) -> Tensor:
        """Forward pass with the transformer network.

        Delegates to ``self.network.predict`` rather than calling the network
        directly, so no target sequence is required.
        """
        return self.network.predict(data)

    @staticmethod
    def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
        """Configure mapping.

        NOTE(review): the ``mapping`` argument is currently ignored; the
        EMNIST mapping is always loaded and returned.

        Args:
            mapping: Unused for now (see note above).

        Returns:
            The EMNIST mapping and the indices of the ``<s>``, ``<e>`` and
            ``<p>`` tokens, in that order.
        """
        mapping, inverse_mapping, _ = emnist_mapping()
        start_index = inverse_mapping["<s>"]
        end_index = inverse_mapping["<e>"]
        pad_index = inverse_mapping["<p>"]
        ignore_tokens = [start_index, end_index, pad_index]
        # TODO: add case for sentence pieces
        return mapping, ignore_tokens

    def _log_prediction(self, data: Tensor, pred: Tensor) -> None:
        """Logs prediction on image with wandb.

        Only the first element of ``data``/``pred`` is logged. Padding tokens
        are dropped from the caption; all other tokens are kept.
        """
        pred_str = "".join(
            self.mapping[i] for i in pred[0].tolist() if i != self.pad_index
        )
        try:
            self.logger.experiment.log(
                {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]}
            )
        except AttributeError:
            # Best effort: skip logging when the attached logger/experiment
            # does not expose the attributes used above.
            pass

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step.

        The network is fed the targets without their last position along
        dim 1, and the loss is computed against the targets without their
        first position (i.e. shifted by one).
        # assumes targets is (batch, sequence) - TODO confirm
        """
        data, targets = batch
        logits = self.network(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step.

        Logs ``val_loss`` using the same shifted-target scheme as
        ``training_step``, then logs an example prediction and ``val_cer``.
        """
        data, targets = batch

        # Slice along dim 1, exactly as in training_step. Slicing along dim 0
        # would drop a sample from the batch instead of shifting the sequence.
        logits = self.network(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("val_loss", loss, prog_bar=True)

        pred = self.network.predict(data)
        self._log_prediction(data, pred)
        self.val_cer(pred, targets)
        self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step.

        Logs an example prediction and ``test_cer``; no loss is computed.
        """
        data, targets = batch
        pred = self.network.predict(data)
        self._log_prediction(data, pred)
        self.test_cer(pred, targets)
        self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)