"""Utility functions for models."""
import Levenshtein as Lev
import torch
from torch import Tensor

from text_recognizer.networks import greedy_decoder


def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
    """Computes the accuracy.

    Args:
        outputs (Tensor): The output from the network.
        labels (Tensor): Ground truth labels.
        pad_index (int): Padding index.

    Returns:
        float: The accuracy for the batch.

    """

    _, predicted = torch.max(outputs, dim=-1)

    # Only score positions that are not padding; pad tokens would otherwise
    # count as trivially correct matches.
    mask = labels != pad_index

    acc = (predicted[mask] == labels[mask]).float().mean()
    return acc.item()
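
# Illustrative example (made-up values): with pad_index=53, labels
# [[5, 7, 53, 53]] and argmax predictions [[5, 9, 0, 0]], only the two
# non-padded positions are scored, giving an accuracy of 1 / 2 = 0.5.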


def cer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the character error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.

    Returns:
        float: The cer for the batch.

    """
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths
    )

    lev_dist = 0
    num_chars = 0

    for prediction, target in zip(decoded_predictions, decoded_targets):
        # Join the decoded tokens and drop spaces so only character edits count.
        prediction = "".join(prediction).replace(" ", "")
        target = "".join(target).replace(" ", "")
        lev_dist += Lev.distance(prediction, target)
        num_chars += len(target)

    # Normalize by the total number of target characters to get an error rate.
    return lev_dist / max(num_chars, 1)
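
# Illustrative example (made-up strings): for a decoded prediction "hxllo wxrld"
# and target "hello world", spaces are dropped and Lev.distance("hxllowxrld",
# "helloworld") is 2, while the target has 10 characters, so the CER is 0.2.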


def wer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the Word error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.

    Returns:
        float: The wer for the batch.

    """
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths
    )

    lev_dist = 0
    num_words = 0

    for prediction, target in zip(decoded_predictions, decoded_targets):
        prediction = "".join(prediction)
        target = "".join(target)

        # Map every unique word to a single character so that Lev.distance
        # operates on words instead of characters.
        vocabulary = set(prediction.split() + target.split())
        word2char = dict(zip(vocabulary, range(len(vocabulary))))

        prediction_words = [chr(word2char[word]) for word in prediction.split()]
        target_words = [chr(word2char[word]) for word in target.split()]

        lev_dist += Lev.distance("".join(prediction_words), "".join(target_words))
        num_words += len(target.split())

    # Normalize by the total number of target words to get an error rate.
    return lev_dist / max(num_words, 1)
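

if __name__ == "__main__":
    # Minimal smoke test for the accuracy metric on random data. The shapes
    # below (batch of 4, sequence length of 10, 54 classes) and the default
    # pad index of 53 are illustrative assumptions, not values mandated by the
    # network. cer and wer are not exercised here because they depend on the
    # output format expected by greedy_decoder.
    batch_size, seq_len, num_classes = 4, 10, 54
    dummy_outputs = torch.randn(batch_size, seq_len, num_classes)
    dummy_labels = torch.randint(0, num_classes - 1, (batch_size, seq_len))
    print(f"accuracy on random data: {accuracy(dummy_outputs, dummy_labels):.3f}")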