From 923e18913a6518d6278d1fc1843a01bacf955c60 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 28 Sep 2022 23:43:56 +0200
Subject: Update metrics

---
 text_recognizer/models/transformer.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index ff4d08d..b2e5d5f 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -52,12 +52,12 @@ class LitTransformer(LitBase):
         """Validation step."""
         data, targets = batch
         preds = self.predict(data)
-        pred_text, target_text = self.get_text(preds, targets)
-        self.val_acc(pred_text, target_text)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
+        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        self.val_acc(preds, targets)
         self.val_cer(pred_text, target_text)
-        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.val_wer(pred_text, target_text)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
+        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -66,20 +66,19 @@ class LitTransformer(LitBase):
 
         # Compute the text prediction.
         preds = self(data)
-        pred_text, target_text = self.get_text(preds, targets)
-        self.test_acc(pred_text, target_text)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        self.test_acc(preds, targets)
         self.test_cer(pred_text, target_text)
-        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.test_wer(pred_text, target_text)
+        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
 
-    def get_text(
-        self, preds: Tensor, targets: Tensor
+    def _get_text(
+        self,
+        xs: Tensor,
     ) -> Tuple[Sequence[str], Sequence[str]]:
-        pred_text = [self.tokenizer.decode(p) for p in preds]
-        target_text = [self.tokenizer.decode(t) for t in targets]
-        return pred_text, target_text
+        return [self.tokenizer.decode(x) for x in xs]
 
     @torch.no_grad()
     def predict(self, x: Tensor) -> Tensor:
-- 
cgit v1.2.3-70-g09d2