From 7268035fb9e57342612a8cc50a1fe04e8841ca2f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 30 Jul 2021 23:15:03 +0200 Subject: attr bug fix, properly loading network --- text_recognizer/models/metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'text_recognizer/models/metrics.py') diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 4117ae2..9793157 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,5 +1,5 @@ """Character Error Rate (CER).""" -from typing import Set, Sequence +from typing import Set import attr import editdistance @@ -12,7 +12,7 @@ from torchmetrics import Metric class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_tokens: Set = attr.ib(converter=set) + ignore_indices: Set = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) @@ -25,8 +25,8 @@ class CharacterErrorRate(Metric): """Update CER.""" bsz = preds.shape[0] for index in range(bsz): - pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens] - target = [t for t in targets[index].tolist() if t not in self.ignore_tokens] + pred = [p for p in preds[index].tolist() if p not in self.ignore_indices] + target = [t for t in targets[index].tolist() if t not in self.ignore_indices] distance = editdistance.distance(pred, target) error = distance / max(len(pred), len(target)) self.error += error -- cgit v1.2.3-70-g09d2