diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200 |
commit | 7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch) | |
tree | 8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer/models/metrics.py | |
parent | 92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff) |
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer/models/metrics.py')
-rw-r--r-- | text_recognizer/models/metrics.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 4117ae2..9793157 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,5 +1,5 @@ """Character Error Rate (CER).""" -from typing import Set, Sequence +from typing import Set import attr import editdistance @@ -12,7 +12,7 @@ from torchmetrics import Metric class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_tokens: Set = attr.ib(converter=set) + ignore_indices: Set = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) @@ -25,8 +25,8 @@ class CharacterErrorRate(Metric): """Update CER.""" bsz = preds.shape[0] for index in range(bsz): - pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens] - target = [t for t in targets[index].tolist() if t not in self.ignore_tokens] + pred = [p for p in preds[index].tolist() if p not in self.ignore_indices] + target = [t for t in targets[index].tolist() if t not in self.ignore_indices] distance = editdistance.distance(pred, target) error = distance / max(len(pred), len(target)) self.error += error |