From 579f3edc3e20ddbe8207ee0c4189a270b2dfedc1 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 2 Sep 2023 01:52:44 +0200 Subject: Refactor lit models --- text_recognizer/model/base.py | 10 ++++++ text_recognizer/model/transformer.py | 59 ++++++++++++++++++++++++------------ 2 files changed, 49 insertions(+), 20 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py index 1cff796..adcb8da 100644 --- a/text_recognizer/model/base.py +++ b/text_recognizer/model/base.py @@ -94,3 +94,13 @@ class LitBase(L.LightningModule): def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" pass + + def is_logged_batch(self) -> bool: + if self.trainer is None: + return False + else: + return self.trainer._logger_connector.should_update_logs + + def add_on_first_batch(self, metrics: dict, output: dict, batch_idx: int) -> None: + if batch_idx == 0: + output.update(metrics) diff --git a/text_recognizer/model/transformer.py b/text_recognizer/model/transformer.py index 23b2a3a..ae6947c 100644 --- a/text_recognizer/model/transformer.py +++ b/text_recognizer/model/transformer.py @@ -1,12 +1,12 @@ """Lightning model for transformer networks.""" -from typing import Callable, Optional, Sequence, Tuple, Type -from text_recognizer.model.greedy_decoder import GreedyDecoder +from typing import Callable, Optional, Tuple, Type import torch from omegaconf import DictConfig from torch import nn, Tensor from torchmetrics import CharErrorRate, WordErrorRate +from .greedy_decoder import GreedyDecoder from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.model.base import LitBase @@ -42,47 +42,66 @@ class LitTransformer(LitBase): def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor: """Non-autoregressive forward pass.""" - return self.network(data, targets) + logits = self.network(data, targets) # [B, N, C] + return logits.permute(0, 2, 1) # [B, C, N] def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, targets = batch logits = self.teacher_forward(data, targets[:, :-1]) loss = self.loss_fn(logits, targets[:, 1:]) + self.log("train/loss", loss, prog_bar=True) + + outputs = {"loss": loss} + + if self.is_logged_batch(): + preds, gts = self.tokenizer.decode_logits( + logits + ), self.tokenizer.batch_decode(targets) + outputs.update({"predictions": preds, "ground_truths": gts}) + return loss - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> dict: """Validation step.""" data, targets = batch preds = self(data) - pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) + preds, gts = self.tokenizer.batch_decode(preds), self.tokenizer.batch_decode( + targets + ) + + self.val_cer(preds, gts) + self.val_wer(preds, gts) - self.val_acc(preds, targets) - self.val_cer(pred_text, target_text) - self.val_wer(pred_text, target_text) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + outputs = {} + self.add_on_first_batch( + {"predictions": preds, "ground_truths": gts}, outputs, batch_idx + ) + return outputs + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> dict: """Test step.""" data, targets = batch preds = self(data) - pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) + preds, gts = self.tokenizer.batch_decode(preds), self.tokenizer.batch_decode( + targets + ) + + self.test_cer(preds, gts) + self.test_wer(preds, gts) - self.test_acc(preds, targets) - self.test_cer(pred_text, target_text) - self.test_wer(pred_text, target_text) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) - def _to_tokens( - self, - indices: Tensor, - ) -> Sequence[str]: - return [self.tokenizer.decode(i) for i in indices] + outputs = {} + self.add_on_first_batch( + {"predictions": preds, "ground_truths": gts}, outputs, batch_idx + ) + return outputs @torch.no_grad() def predict(self, x: Tensor) -> Tensor: -- cgit v1.2.3-70-g09d2