diff options
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index dcec756..2c74b7e 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -3,11 +3,12 @@ from typing import Optional, Tuple, Type import torch from omegaconf import DictConfig -from torch import Tensor, nn +from torch import nn, Tensor from text_recognizer.data.mappings import EmnistMapping from text_recognizer.models.base import LitBase -from text_recognizer.models.metrics import CharacterErrorRate +from text_recognizer.models.metrics.cer import CharacterErrorRate +from text_recognizer.models.metrics.wer import WordErrorRate class LitTransformer(LitBase): @@ -18,16 +19,13 @@ class LitTransformer(LitBase): network: Type[nn.Module], loss_fn: Type[nn.Module], optimizer_config: DictConfig, - lr_scheduler_config: Optional[DictConfig], mapping: EmnistMapping, + lr_scheduler_config: Optional[DictConfig] = None, max_output_len: int = 682, start_token: str = "<s>", end_token: str = "<e>", pad_token: str = "<p>", ) -> None: - super().__init__( - network, loss_fn, optimizer_config, lr_scheduler_config, mapping - ) self.max_output_len = max_output_len self.start_token = start_token self.end_token = end_token @@ -38,6 +36,16 @@ class LitTransformer(LitBase): self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) self.val_cer = CharacterErrorRate(self.ignore_indices) self.test_cer = CharacterErrorRate(self.ignore_indices) + self.val_wer = WordErrorRate(self.ignore_indices) + self.test_wer = WordErrorRate(self.ignore_indices) + super().__init__( + network, + loss_fn, + optimizer_config, + lr_scheduler_config, + mapping, + self.pad_index, + ) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" @@ -59,6 +67,8 @@ class LitTransformer(LitBase): self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) self.val_cer(preds, targets) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) + self.val_wer(preds, targets) + self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -66,10 +76,12 @@ class LitTransformer(LitBase): # Compute the text prediction. pred = self(data) - self.test_cer(pred, targets) - self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) self.test_acc(pred, targets) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + self.test_cer(pred, targets) + self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_wer(pred, targets) + self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) @torch.no_grad() def predict(self, x: Tensor) -> Tensor: |