summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py26
1 files changed, 5 insertions, 21 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index bc7e313..6be0ac5 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -2,9 +2,7 @@
from typing import Dict, List, Optional, Union, Tuple, Type
from omegaconf import DictConfig
-from torch import nn
-from torch import Tensor
-import wandb
+from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.models.metrics import CharacterErrorRate
@@ -44,24 +42,12 @@ class LitTransformerModel(LitBaseModel):
# TODO: add case for sentence pieces
return mapping, ignore_tokens
- def _log_prediction(self, data: Tensor, pred: Tensor) -> None:
- """Logs prediction on image with wandb."""
- pred_str = "".join(
- self.mapping[i] for i in pred[0].tolist() if i != 3
- ) # pad index is 3
- try:
- self.logger.experiment.log(
- {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]}
- )
- except AttributeError:
- pass
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, targets = batch
logits = self.network(data, targets[:, :-1])
loss = self.loss_fn(logits, targets[:, 1:])
- self.log("train_loss", loss)
+ self.log("train/loss", loss)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -70,17 +56,15 @@ class LitTransformerModel(LitBaseModel):
logits = self.network(data, targets[:-1])
loss = self.loss_fn(logits, targets[1:])
- self.log("val_loss", loss, prog_bar=True)
+ self.log("val/loss", loss, prog_bar=True)
pred = self.network.predict(data)
- self._log_prediction(data, pred)
self.val_cer(pred, targets)
- self.log("val_cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
data, targets = batch
pred = self.network.predict(data)
- self._log_prediction(data, pred)
self.test_cer(pred, targets)
- self.log("test_cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)