Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 26
1 file changed, 5 insertions, 21 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index bc7e313..6be0ac5 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -2,9 +2,7 @@
 from typing import Dict, List, Optional, Union, Tuple, Type
 
 from omegaconf import DictConfig
-from torch import nn
-from torch import Tensor
-import wandb
+from torch import nn, Tensor
 
 from text_recognizer.data.emnist import emnist_mapping
 from text_recognizer.models.metrics import CharacterErrorRate
@@ -44,24 +42,12 @@ class LitTransformerModel(LitBaseModel):
         # TODO: add case for sentence pieces
         return mapping, ignore_tokens
 
-    def _log_prediction(self, data: Tensor, pred: Tensor) -> None:
-        """Logs prediction on image with wandb."""
-        pred_str = "".join(
-            self.mapping[i] for i in pred[0].tolist() if i != 3
-        )  # pad index is 3
-        try:
-            self.logger.experiment.log(
-                {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]}
-            )
-        except AttributeError:
-            pass
-
     def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
         """Training step."""
         data, targets = batch
         logits = self.network(data, targets[:, :-1])
         loss = self.loss_fn(logits, targets[:, 1:])
-        self.log("train_loss", loss)
+        self.log("train/loss", loss)
         return loss
 
     def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -70,17 +56,15 @@ class LitTransformerModel(LitBaseModel):
         logits = self.network(data, targets[:-1])
         loss = self.loss_fn(logits, targets[1:])
-        self.log("val_loss", loss, prog_bar=True)
+        self.log("val/loss", loss, prog_bar=True)
 
         pred = self.network.predict(data)
-        self._log_prediction(data, pred)
         self.val_cer(pred, targets)
-        self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
         data, targets = batch
         pred = self.network.predict(data)
-        self._log_prediction(data, pred)
         self.test_cer(pred, targets)
-        self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
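Note on the change: renaming the metrics to "train/loss", "val/loss", "val/cer", and "test/cer" lets loggers such as W&B and TensorBoard group them under per-stage sections, while the wandb-specific `_log_prediction` helper is dropped from the LightningModule. The commit does not say where (or whether) that image logging moves; the sketch below is only a hypothetical illustration of how the same behaviour could live in a PyTorch Lightning callback instead, reusing `mapping`, the pad index 3, and `network.predict` from the removed code. The callback name and constructor arguments are assumptions, not part of this repository.

    # Hypothetical sketch, not part of this commit: prediction-image logging
    # moved out of LitTransformerModel and into a Lightning callback.
    import wandb
    from pytorch_lightning import Callback


    class LogPredictionExample(Callback):
        """Logs the first validation prediction of a batch as a wandb image."""

        def __init__(self, mapping, pad_index: int = 3) -> None:
            self.mapping = mapping  # index-to-character mapping, as in the model
            self.pad_index = pad_index  # pad index is 3 in the removed helper

        def on_validation_batch_end(
            self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
        ) -> None:
            data, _ = batch
            pred = pl_module.network.predict(data)
            pred_str = "".join(
                self.mapping[i] for i in pred[0].tolist() if i != self.pad_index
            )
            try:
                trainer.logger.experiment.log(
                    {"val/pred_examples": [wandb.Image(data[0], caption=pred_str)]}
                )
            except AttributeError:
                # Logger has no wandb run attached (e.g. CSV logger, fast_dev_run).
                pass

The callback would be passed to the Trainer (e.g. `Trainer(callbacks=[LogPredictionExample(mapping)])`), keeping the LightningModule free of logger-specific code.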