summaryrefslogtreecommitdiff
path: root/tests/test_cer.py
blob: 6b7565e0eb2d9a2afc72ac3b1716970c3ff81500 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""Test the CER metric."""
import torch

from text_recognizer.models.metrics import CharacterErrorRate


def test_character_error_rate() -> None:
    """Test CER computation.

    The metric is built with ``[0, 1]``. The expected values below assume
    these tokens (presumably start/end markers — confirm against
    ``CharacterErrorRate``) are dropped before the sequences are compared.
    Every prediction row is scored against the same target row
    ``[0, 2, 2, 3, 3, 1]``. The expected metric is the mean of the per-row
    error rates.
    """
    metric = CharacterErrorRate([0, 1])
    preds = torch.Tensor(
        [
            [0, 2, 2, 3, 3, 1],  # identical to target -> expected error 0
            [0, 2, 1, 1, 1, 1],  # [2] vs [2, 2, 3, 3]: 3 edits / 4 -> 0.75
            [0, 2, 2, 4, 4, 1],  # [2, 2, 4, 4] vs [2, 2, 3, 3]: 2 / 4 -> 0.5
        ]
    )
    targets = torch.Tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )

    metric(preds, targets)
    cer = metric.compute()

    expected = sum([0.0, 0.75, 0.5]) / 3
    # 1.25 / 3 is not exactly representable, and the metric may accumulate in a
    # different precision than Python's float64, so avoid exact equality.
    assert abs(float(cer) - expected) < 1e-6