diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:26 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:26 +0200 |
commit | 0540237d794ab2071764dc74e4d3bb52f5bf44be (patch) | |
tree | dad3469f843da16716871d0b9805bf0301aa6cfe /text_recognizer/models/metrics.py | |
parent | bf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 (diff) |
Update metrics
Diffstat (limited to 'text_recognizer/models/metrics.py')
-rw-r--r-- | text_recognizer/models/metrics.py | 36 |
1 files changed, 0 insertions, 36 deletions
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py deleted file mode 100644 index 3cb16b5..0000000 --- a/text_recognizer/models/metrics.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Character Error Rate (CER).""" -from typing import Sequence - -import editdistance -import torch -from torch import Tensor -from torchmetrics import Metric - - -class CharacterErrorRate(Metric): - """Character error rate metric, computed using Levenshtein distance.""" - - def __init__(self, ignore_indices: Sequence[Tensor]) -> None: - super().__init__() - self.ignore_indices = set(ignore_indices) - self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - self.error: Tensor - self.total: Tensor - - def update(self, preds: Tensor, targets: Tensor) -> None: - """Update CER.""" - bsz = preds.shape[0] - for index in range(bsz): - pred = [p for p in preds[index].tolist() if p not in self.ignore_indices] - target = [ - t for t in targets[index].tolist() if t not in self.ignore_indices - ] - distance = editdistance.distance(pred, target) - error = distance / max(len(pred), len(target)) - self.error += error - self.total += bsz - - def compute(self) -> Tensor: - """Compute CER.""" - return self.error / self.total |