Diffstat (limited to 'text_recognizer/models')
-rw-r--r--  text_recognizer/models/base.py              |  9
-rw-r--r--  text_recognizer/models/metrics.py           | 36
-rw-r--r--  text_recognizer/models/metrics/__init__.py  |  0
-rw-r--r--  text_recognizer/models/metrics/cer.py       | 23
-rw-r--r--  text_recognizer/models/metrics/wer.py       | 23
-rw-r--r--  text_recognizer/models/transformer.py       | 28
6 files changed, 71 insertions, 48 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index f917635..bb4e695 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -6,7 +6,7 @@ import torch
 from loguru import logger as log
 from omegaconf import DictConfig
 from pytorch_lightning import LightningModule
-from torch import Tensor, nn
+from torch import nn, Tensor
 from torchmetrics import Accuracy
 
 from text_recognizer.data.mappings import EmnistMapping
@@ -22,6 +22,7 @@ class LitBase(LightningModule):
         optimizer_config: DictConfig,
         lr_scheduler_config: Optional[DictConfig],
         mapping: EmnistMapping,
+        ignore_index: Optional[int] = None,
     ) -> None:
         super().__init__()
 
@@ -32,9 +33,9 @@ class LitBase(LightningModule):
         self.mapping = mapping
 
         # Placeholders
-        self.train_acc = Accuracy(mdmc_reduce="samplewise")
-        self.val_acc = Accuracy(mdmc_reduce="samplewise")
-        self.test_acc = Accuracy(mdmc_reduce="samplewise")
+        self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+        self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+        self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
 
     def optimizer_zero_grad(
         self,
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
deleted file mode 100644
index 3cb16b5..0000000
--- a/text_recognizer/models/metrics.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""Character Error Rate (CER)."""
-from typing import Sequence
-
-import editdistance
-import torch
-from torch import Tensor
-from torchmetrics import Metric
-
-
-class CharacterErrorRate(Metric):
-    """Character error rate metric, computed using Levenshtein distance."""
-
-    def __init__(self, ignore_indices: Sequence[Tensor]) -> None:
-        super().__init__()
-        self.ignore_indices = set(ignore_indices)
-        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
-        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
-        self.error: Tensor
-        self.total: Tensor
-
-    def update(self, preds: Tensor, targets: Tensor) -> None:
-        """Update CER."""
-        bsz = preds.shape[0]
-        for index in range(bsz):
-            pred = [p for p in preds[index].tolist() if p not in self.ignore_indices]
-            target = [
-                t for t in targets[index].tolist() if t not in self.ignore_indices
-            ]
-            distance = editdistance.distance(pred, target)
-            error = distance / max(len(pred), len(target))
-            self.error += error
-        self.total += bsz
-
-    def compute(self) -> Tensor:
-        """Compute CER."""
-        return self.error / self.total
diff --git a/text_recognizer/models/metrics/__init__.py b/text_recognizer/models/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/models/metrics/__init__.py
diff --git a/text_recognizer/models/metrics/cer.py b/text_recognizer/models/metrics/cer.py
new file mode 100644
index 0000000..238ecc3
--- /dev/null
+++ b/text_recognizer/models/metrics/cer.py
@@ -0,0 +1,23 @@
+"""Character Error Rate (CER)."""
+from typing import Sequence
+
+import torch
+import torchmetrics
+
+
+class CharacterErrorRate(torchmetrics.CharErrorRate):
+    """Character error rate metric, allowing for tokens to be ignored."""
+
+    def __init__(self, ignore_tokens: Sequence[int], *args):
+        super().__init__(*args)
+        self.ignore_tokens = set(ignore_tokens)
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
+        preds_l = [
+            [t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()
+        ]
+        targets_l = [
+            [t for t in target if t not in self.ignore_tokens]
+            for target in targets.tolist()
+        ]
+        super().update(preds_l, targets_l)
diff --git a/text_recognizer/models/metrics/wer.py b/text_recognizer/models/metrics/wer.py
new file mode 100644
index 0000000..78f5854
--- /dev/null
+++ b/text_recognizer/models/metrics/wer.py
@@ -0,0 +1,23 @@
+"""Word Error Rate (WER)."""
+from typing import Sequence
+
+import torch
+import torchmetrics
+
+
+class WordErrorRate(torchmetrics.WordErrorRate):
+    """Word error rate metric, allowing for tokens to be ignored."""
+
+    def __init__(self, ignore_tokens: Sequence[int], *args):
+        super().__init__(*args)
+        self.ignore_tokens = set(ignore_tokens)
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
+        preds_l = [
+            [t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()
+        ]
+        targets_l = [
+            [t for t in target if t not in self.ignore_tokens]
+            for target in targets.tolist()
+        ]
+        super().update(preds_l, targets_l)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index dcec756..2c74b7e 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -3,11 +3,12 @@ from typing import Optional, Tuple, Type
 
 import torch
 from omegaconf import DictConfig
-from torch import Tensor, nn
+from torch import nn, Tensor
 
 from text_recognizer.data.mappings import EmnistMapping
 from text_recognizer.models.base import LitBase
-from text_recognizer.models.metrics import CharacterErrorRate
+from text_recognizer.models.metrics.cer import CharacterErrorRate
+from text_recognizer.models.metrics.wer import WordErrorRate
 
 
 class LitTransformer(LitBase):
@@ -18,16 +19,13 @@ class LitTransformer(LitBase):
         network: Type[nn.Module],
         loss_fn: Type[nn.Module],
         optimizer_config: DictConfig,
-        lr_scheduler_config: Optional[DictConfig],
         mapping: EmnistMapping,
+        lr_scheduler_config: Optional[DictConfig] = None,
         max_output_len: int = 682,
         start_token: str = "<s>",
         end_token: str = "<e>",
         pad_token: str = "<p>",
     ) -> None:
-        super().__init__(
-            network, loss_fn, optimizer_config, lr_scheduler_config, mapping
-        )
         self.max_output_len = max_output_len
         self.start_token = start_token
         self.end_token = end_token
@@ -38,6 +36,16 @@ class LitTransformer(LitBase):
         self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
         self.val_cer = CharacterErrorRate(self.ignore_indices)
         self.test_cer = CharacterErrorRate(self.ignore_indices)
+        self.val_wer = WordErrorRate(self.ignore_indices)
+        self.test_wer = WordErrorRate(self.ignore_indices)
+        super().__init__(
+            network,
+            loss_fn,
+            optimizer_config,
+            lr_scheduler_config,
+            mapping,
+            self.pad_index,
+        )
 
     def forward(self, data: Tensor) -> Tensor:
         """Forward pass with the transformer network."""
@@ -59,6 +67,8 @@
         self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
         self.val_cer(preds, targets)
         self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.val_wer(preds, targets)
+        self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -66,10 +76,12 @@
         data, targets = batch
         # Compute the text prediction.
         pred = self(data)
-        self.test_cer(pred, targets)
-        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.test_acc(pred, targets)
         self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        self.test_cer(pred, targets)
+        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.test_wer(pred, targets)
+        self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
 
     @torch.no_grad()
     def predict(self, x: Tensor) -> Tensor:
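The `ignore_index` plumbed through `LitBase` keeps padding positions from inflating token accuracy. A minimal sketch of the effect, assuming a torchmetrics version that still accepts `mdmc_reduce` (as the code above does); the pad index and the tensors are made up for illustration:

    import torch
    from torchmetrics import Accuracy

    PAD = 3  # hypothetical pad token index

    acc = Accuracy(mdmc_reduce="samplewise", ignore_index=PAD)

    preds = torch.tensor([[7, 8, 9, PAD]])    # (batch, seq) predicted token ids
    targets = torch.tensor([[7, 8, 5, PAD]])  # one wrong token, one pad position

    # The pad position is excluded, so the score is 2/3 rather than 3/4.
    print(acc(preds, targets))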
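Both new metric wrappers strip special tokens before delegating to the underlying torchmetrics implementations. The filtering can be shown standalone; a sketch that mirrors the per-sample normalised edit distance the deleted metrics.py computed, using the same editdistance package and hypothetical token ids (1 = <s>, 2 = <e>, 3 = <p>):

    from typing import List, Sequence

    import editdistance
    import torch


    def drop_special(seqs: torch.Tensor, ignore_tokens: Sequence[int]) -> List[List[int]]:
        """Remove start/end/pad ids from each row of a (batch, seq) id tensor."""
        ignore = set(ignore_tokens)
        return [[t for t in seq if t not in ignore] for seq in seqs.tolist()]


    preds = torch.tensor([[1, 7, 8, 9, 2, 3, 3]])
    targets = torch.tensor([[1, 7, 8, 2, 3, 3, 3]])

    pred = drop_special(preds, (1, 2, 3))[0]      # [7, 8, 9]
    target = drop_special(targets, (1, 2, 3))[0]  # [7, 8]

    # One edit over a maximum filtered length of three tokens.
    cer = editdistance.distance(pred, target) / max(len(pred), len(target))
    print(cer)  # 0.333...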