summaryrefslogtreecommitdiff
path: root/text_recognizer/models/metrics.py
blob: f83c9e4ee1b83d90abf96b9727652c7160e5d181 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Character Error Rate (CER)."""
from typing import Set

import attr
import editdistance
import torch
from torch import Tensor
from torchmetrics import Metric


def _to_index_set(indices) -> Set[int]:
    """Convert an iterable of ignore indices to a set of plain Python ints.

    ``CharacterErrorRate.update`` tests membership against values produced by
    ``Tensor.tolist()``, which are Python scalars. ``torch.Tensor`` hashes by
    identity, so a set of 0-d tensors would never match those scalars. Casting
    every entry with ``int`` makes the membership test work whether ints or
    tensors were passed in.
    """
    return {int(index) for index in indices}


# eq=False keeps identity-based __eq__/__hash__, which Metric/nn.Module rely on.
@attr.s(eq=False)
class CharacterErrorRate(Metric):
    """Character error rate metric, computed using Levenshtein distance.

    For every sample in a batch, the predicted and target index sequences are
    stripped of ``ignore_indices`` and compared with ``editdistance.distance``.
    The per-sample error is that distance divided by the length of the longer
    of the two sequences. ``compute`` returns the sum of per-sample errors
    divided by the number of samples accumulated.

    Note: normalization is by ``max(len(pred), len(target))`` rather than by
    ``len(target)`` alone, which is the more common CER definition.

    Args:
        ignore_indices: Indices to drop from both predictions and targets
            before comparison. Entries are converted to ``int``.
    """

    # Indices removed from both sequences before computing edit distance.
    ignore_indices: Set[int] = attr.ib(converter=_to_index_set)
    # Metric states; registered through ``add_state`` in __attrs_post_init__.
    error: Tensor = attr.ib(init=False)
    total: Tensor = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Initialise the Metric base class and register accumulator states."""
        super().__init__()
        # Running sum of per-sample normalized edit distances.
        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Running count of samples seen.
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, targets: Tensor) -> None:
        """Accumulate per-sample errors for one batch.

        Args:
            preds: Predicted index sequences, one row per sample along dim 0.
                Each row's ``.tolist()`` is iterated and its elements are
                hashed, so this is expected to be 2-D (batch, seq).
            targets: Target index sequences, indexed by sample along dim 0
                in the same way as ``preds``.
        """
        ignore = self.ignore_indices
        bsz = preds.shape[0]
        for index in range(bsz):
            pred = [p for p in preds[index].tolist() if p not in ignore]
            target = [t for t in targets[index].tolist() if t not in ignore]
            distance = editdistance.distance(pred, target)
            # Clamp to 1 so two empty sequences (e.g. all-padding rows) do not
            # raise ZeroDivisionError; their distance is 0, giving error 0.
            length = max(len(pred), len(target), 1)
            self.error += distance / length
        self.total += bsz

    def compute(self) -> Tensor:
        """Return the mean per-sample error over all accumulated samples.

        If no samples have been accumulated (``total`` is 0), the tensor
        division yields NaN.
        """
        return self.error / self.total