summaryrefslogtreecommitdiff
path: root/text_recognizer/models/metrics.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 16:05:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 16:05:13 +0200
commit31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch)
treef529d975d18d718a5d646e93f746d8be6f2f5cfe /text_recognizer/models/metrics.py
parent36964354407d0fdf73bdca2f611fee1664860197 (diff)
Reformat test for CER
Diffstat (limited to 'text_recognizer/models/metrics.py')
-rw-r--r--text_recognizer/models/metrics.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
new file mode 100644
index 0000000..58d0537
--- /dev/null
+++ b/text_recognizer/models/metrics.py
@@ -0,0 +1,32 @@
+"""Character Error Rate (CER)."""
+from typing import Sequence
+
+import editdistance
+import torch
+from torch import Tensor
+import torchmetrics
+
+
+class CharacterErrorRate(torchmetrics.Metric):
+ """Character error rate metric, computed using Levenshtein distance."""
+
+ def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+ super().__init__()
+ self.ignore_tokens = set(ignore_tokens)
+ self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ def update(self, preds: Tensor, targets: Tensor) -> None:
+ """Update CER."""
+ bsz = preds.shape[0]
+ for index in range(bsz):
+ pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
+ target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+ distance = editdistance.distance(pred, target)
+ error = distance / max(len(pred), len(target))
+ self.error += error
+ self.total += bsz
+
+ def compute(self) -> Tensor:
+ """Compute CER."""
+ return self.error / self.total