path: root/src/text_recognizer/models/metrics.py
blob: 42c3c6e7ba4359eab0189dfc287d44c83e8198ff
"""Utility functions for models."""
import Levenshtein as Lev
import torch
from torch import Tensor

from text_recognizer.networks import greedy_decoder


def accuracy(outputs: Tensor, labels: Tensor) -> float:
    """Computes the accuracy.

    Args:
        outputs (Tensor): The output from the network.
        labels (Tensor): Ground truth labels.

    Returns:
        float: The accuracy for the batch.

    """
    _, predicted = torch.max(outputs, dim=-1)
    # Fraction of positions where the prediction matches the label.
    acc = (predicted == labels).float().mean().item()
    return acc


def cer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the character error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.

    Returns:
        float: The character error rate for the batch.

    """
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths
    )

    lev_dist = 0

    for prediction, target in zip(decoded_predictions, decoded_targets):
        prediction = "".join(prediction)
        target = "".join(target)
        prediction, target = (
            prediction.replace(" ", ""),
            target.replace(" ", ""),
        )
        lev_dist += Lev.distance(prediction, target)
    return lev_dist / len(decoded_predictions)


def wer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the Word error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.

    Returns:
        float: The word error rate for the batch.

    """
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths
    )

    lev_dist = 0

    for prediction, target in zip(decoded_predictions, decoded_targets):
        prediction = "".join(prediction)
        target = "".join(target)

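        # Map each unique word to a single character so that a character-level
        # Levenshtein distance over the encoded strings counts word-level edits.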
        words = set(prediction.split() + target.split())
        word2char = dict(zip(words, range(len(words))))

        w1 = [chr(word2char[w]) for w in prediction.split()]
        w2 = [chr(word2char[w]) for w in target.split()]

        lev_dist += Lev.distance("".join(w1), "".join(w2))

    return lev_dist / len(decoded_predictions)
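

if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the original module). It only
    # exercises `accuracy`, since `cer` and `wer` depend on greedy_decoder and
    # real network outputs; the shapes below are illustrative assumptions.
    dummy_outputs = torch.randn(4, 10)  # (batch, num_classes) class scores
    dummy_labels = torch.randint(0, 10, (4,))  # (batch,) integer labels
    print(f"accuracy: {accuracy(dummy_outputs, dummy_labels):.3f}")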