From 31e9673eef3088f08e3ee6aef8b78abd701ca329 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 4 Apr 2021 16:05:13 +0200
Subject: Reformat test for CER

---
 text_recognizer/models/metrics.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 text_recognizer/models/metrics.py

diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
new file mode 100644
index 0000000..58d0537
--- /dev/null
+++ b/text_recognizer/models/metrics.py
@@ -0,0 +1,32 @@
+"""Character Error Rate (CER)."""
+from typing import Sequence
+
+import editdistance
+import torch
+from torch import Tensor
+import torchmetrics
+
+
+class CharacterErrorRate(torchmetrics.Metric):
+    """Character error rate metric, computed using Levenshtein distance."""
+
+    def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+        super().__init__()
+        self.ignore_tokens = set(ignore_tokens)
+        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, targets: Tensor) -> None:
+        """Update CER."""
+        bsz = preds.shape[0]
+        for index in range(bsz):
+            pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
+            target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+            distance = editdistance.distance(pred, target)
+            error = distance / max(len(pred), len(target))
+            self.error += error
+        self.total += bsz
+
+    def compute(self) -> Tensor:
+        """Compute CER."""
+        return self.error / self.total
--
cgit v1.2.3-70-g09d2
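
Usage note (not part of the commit): a minimal sketch of how this metric could be exercised, assuming the package is importable as text_recognizer.models.metrics and that token id 0 is a padding token to be ignored; the padding id and the example tensors are illustrative, not taken from the repository.

import torch

from text_recognizer.models.metrics import CharacterErrorRate

# Token id 0 is assumed here to be a padding token that should not count
# towards the error; the id and the tensors below are illustrative only.
metric = CharacterErrorRate(ignore_tokens=[0])

preds = torch.tensor([[2, 3, 4, 0], [5, 6, 0, 0]])    # predicted token ids, zero-padded
targets = torch.tensor([[2, 3, 5, 0], [5, 6, 7, 0]])  # reference token ids, zero-padded

metric.update(preds, targets)
# Each sample contributes edit_distance / max(len(pred), len(target)), and
# compute() divides the accumulated error by the number of samples seen.
print(metric.compute())  # tensor(0.3333)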