diff options
Diffstat (limited to 'text_recognizer/models/metrics.py')
-rw-r--r-- | text_recognizer/models/metrics.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 4117ae2..9793157 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,5 +1,5 @@ """Character Error Rate (CER).""" -from typing import Set, Sequence +from typing import Set import attr import editdistance @@ -12,7 +12,7 @@ from torchmetrics import Metric class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_tokens: Set = attr.ib(converter=set) + ignore_indices: Set = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) @@ -25,8 +25,8 @@ class CharacterErrorRate(Metric): """Update CER.""" bsz = preds.shape[0] for index in range(bsz): - pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens] - target = [t for t in targets[index].tolist() if t not in self.ignore_tokens] + pred = [p for p in preds[index].tolist() if p not in self.ignore_indices] + target = [t for t in targets[index].tolist() if t not in self.ignore_indices] distance = editdistance.distance(pred, target) error = distance / max(len(pred), len(target)) self.error += error |