From 31e9673eef3088f08e3ee6aef8b78abd701ca329 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 4 Apr 2021 16:05:13 +0200 Subject: Reformat test for CER --- text_recognizer/models/base.py | 30 +++++++++++++++++++++--------- text_recognizer/models/metrics.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 9 deletions(-) create mode 100644 text_recognizer/models/metrics.py (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index e86b478..0c70625 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -6,6 +6,7 @@ import pytorch_lightning as pl import torch from torch import nn from torch import Tensor +import torchmetrics from text_recognizer import networks @@ -13,7 +14,14 @@ from text_recognizer import networks class BaseModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" - def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None: + def __init__( + self, + network_args: Dict, + optimizer_args: Dict, + lr_scheduler_args: Dict, + criterion_args: Dict, + monitor: str = "val_loss", + ) -> None: super().__init__() self.monitor = monitor self.network = getattr(networks, network_args["type"])(**network_args["args"]) @@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule): self.loss_fn = self.configure_criterion(criterion_args) # Accuracy metric - self.train_acc = pl.metrics.Accuracy() - self.val_acc = pl.metrics.Accuracy() - self.test_acc = pl.metrics.Accuracy() + self.train_acc = torchmetrics.Accuracy() + self.val_acc = torchmetrics.Accuracy() + self.test_acc = torchmetrics.Accuracy() @staticmethod def configure_criterion(criterion_args: Dict) -> Type[nn.Module]: @@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule): optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) args = {} or self.lr_scheduler_args["args"] - scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args) - return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor} + scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])( + **args + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": self.monitor, + } def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" @@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule): loss = self.loss_fn(logits, targets) self.log("train_loss", loss) self.train_acc(logits, targets) - self.log("train_acc": self.train_acc, on_step=False, on_epoch=True) + self.log("train_acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule): logits = self(data) self.test_acc(logits, targets) self.log("test_acc", self.test_acc, on_step=False, on_epoch=True) - - diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py new file mode 100644 index 0000000..58d0537 --- /dev/null +++ b/text_recognizer/models/metrics.py @@ -0,0 +1,32 @@ +"""Character Error Rate (CER).""" +from typing import Sequence + +import editdistance +import torch +from torch import Tensor +import torchmetrics + + +class CharacterErrorRate(torchmetrics.Metric): + """Character error rate metric, computed using Levenshtein distance.""" + + def __init__(self, ignore_tokens: Sequence[int], *args) -> None: + super().__init__() + self.ignore_tokens = set(ignore_tokens) + self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, targets: Tensor) -> None: + """Update CER.""" + bsz = preds.shape[0] + for index in range(bsz): + pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens] + target = [t for t in targets[index].tolist() if t not in self.ignore_tokens] + distance = editdistance.distance(pred, target) + error = distance / max(len(pred), len(target)) + self.error += error + self.total += bsz + + def compute(self) -> Tensor: + """Compute CER.""" + return self.error / self.total -- cgit v1.2.3-70-g09d2