Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 27 |
1 file changed, 13 insertions, 14 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index ff4d08d..b2e5d5f 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -52,12 +52,12 @@ class LitTransformer(LitBase):
         """Validation step."""
         data, targets = batch
         preds = self.predict(data)
-        pred_text, target_text = self.get_text(preds, targets)
-        self.val_acc(pred_text, target_text)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
+        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        self.val_acc(preds, targets)
         self.val_cer(pred_text, target_text)
-        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.val_wer(pred_text, target_text)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
+        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -66,20 +66,19 @@ class LitTransformer(LitBase):
 
         # Compute the text prediction.
         preds = self(data)
-        pred_text, target_text = self.get_text(preds, targets)
-        self.test_acc(pred_text, target_text)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        self.test_acc(preds, targets)
         self.test_cer(pred_text, target_text)
-        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.test_wer(pred_text, target_text)
+        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
 
-    def get_text(
-        self, preds: Tensor, targets: Tensor
-    ) -> Tuple[Sequence[str], Sequence[str]]:
-        pred_text = [self.tokenizer.decode(p) for p in preds]
-        target_text = [self.tokenizer.decode(t) for t in targets]
-        return pred_text, target_text
+    def _get_text(
+        self,
+        xs: Tensor,
+    ) -> Sequence[str]:
+        return [self.tokenizer.decode(x) for x in xs]
 
     @torch.no_grad()
     def predict(self, x: Tensor) -> Tensor:
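
For context: the patch replaces the tuple-returning get_text(preds, targets) with a single-tensor _get_text helper that is applied once to predictions and once to targets, and it groups the self.log calls after all metric updates. The character and word error rates are updated on decoded strings, while accuracy now compares the raw token tensors directly. A minimal runnable sketch of the decoding helper is below; StubTokenizer is a hypothetical stand-in, since the repo's real tokenizer is constructed elsewhere and only its decode method matters here.

from typing import Dict, Sequence

import torch
from torch import Tensor


class StubTokenizer:
    """Hypothetical stand-in for the repo's tokenizer; only decode() is used."""

    def __init__(self) -> None:
        self.itos: Dict[int, str] = {0: "<pad>", 1: "a", 2: "b", 3: "c"}

    def decode(self, x: Tensor) -> str:
        # Map token ids back to characters, skipping padding.
        return "".join(self.itos[int(i)] for i in x if int(i) != 0)


def _get_text(tokenizer: StubTokenizer, xs: Tensor) -> Sequence[str]:
    # One decoded string per sample; xs is a (batch, seq_len) tensor of token ids.
    return [tokenizer.decode(x) for x in xs]


tokenizer = StubTokenizer()
batch = torch.tensor([[1, 2, 3, 0], [3, 2, 0, 0]])
print(_get_text(tokenizer, batch))  # ['abc', 'cb']

On the logging reordering: with torchmetrics-style metric objects, passing the metric itself to self.log with on_epoch=True lets Lightning call compute() and reset the metric at epoch end, so updating all metrics first and logging afterwards changes only the code layout, not the reported values.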