summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/metrics.py
blob: 2605731b85ce7ec28c24695bf089a751b3e5df4d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Utility functions for models."""
from typing import Optional

from einops import rearrange
import Levenshtein as Lev
import torch
from torch import Tensor

from text_recognizer.networks import greedy_decoder


def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
    """Computes the accuracy over non-padding positions.

    Args:
        outputs (Tensor): The output from the network. Class scores are read
            from the last dimension.
        labels (Tensor): Ground truth labels, shaped like ``outputs`` without
            its last dimension.
        pad_index (int): Padding index. Positions where ``labels == pad_index``
            are excluded from both the count of correct predictions and the
            total.

    Returns:
        float: Fraction of non-pad positions predicted correctly, or 0.0 when
        there are no non-pad positions.

    """
    predicted = outputs.argmax(dim=-1)

    # Boolean mask of the positions that take part in the metric. Neither
    # ``labels`` nor ``outputs`` is modified in place, so callers can keep
    # using them after this call.
    mask = labels != pad_index
    num_tokens = int(mask.sum().item())
    if num_tokens == 0:
        return 0.0

    # Only count hits at non-pad positions; pads must not inflate the score.
    correct = int(((predicted == labels) & mask).sum().item())
    return correct / num_tokens


def cer(
    outputs: Tensor,
    targets: Tensor,
    batch_size: Optional[int] = None,
    blank_label: Optional[int] = 79,
) -> float:
    """Computes the character error rate.

    Predictions are greedily decoded, spaces are stripped from both the
    prediction and the target, and the Levenshtein distance between them is
    taken. The result is the mean of these distances over the batch; it is not
    normalised by target length.

    Args:
        outputs (Tensor): The output from the network. After the optional
            reshape below, dimension 1 is treated as the batch dimension
            (``t b v`` layout). A flattened ``(b * t, v)`` tensor is accepted
            when ``batch_size`` is given.
        targets (Tensor): Ground truth labels, ``(b, t)`` after the optional
            reshape. A flattened ``(b * t,)`` tensor is accepted when
            ``batch_size`` is given.
        batch_size (Optional[int]): Batch size if target and output have been
            flattened.
        blank_label (Optional[int]): The blank character to be ignored.
            Defaults to 79.

    Returns:
        float: The mean per-sample character edit distance for the batch, or
        0.0 for an empty batch.

    """
    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)

    # Every target in the batch is treated as having the full padded length.
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths, blank_label=blank_label,
    )

    if len(decoded_predictions) == 0:
        return 0.0

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        # Spaces are ignored so that only character errors are counted.
        prediction = "".join(prediction).replace(" ", "")
        target = "".join(target).replace(" ", "")
        lev_dist += Lev.distance(prediction, target)
    return lev_dist / len(decoded_predictions)


def wer(
    outputs: Tensor,
    targets: Tensor,
    batch_size: Optional[int] = None,
    blank_label: Optional[int] = 79,
) -> float:
    """Computes the word error rate.

    Predictions are greedily decoded and split on whitespace. Each distinct
    word in a prediction/target pair is mapped to a unique single character,
    so a character-level Levenshtein distance on the mapped strings equals the
    word-level edit distance. The result is the mean of these distances over
    the batch; it is not normalised by target length.

    Args:
        outputs (Tensor): The output from the network. After the optional
            reshape below, dimension 1 is treated as the batch dimension
            (``t b v`` layout). A flattened ``(b * t, v)`` tensor is accepted
            when ``batch_size`` is given.
        targets (Tensor): Ground truth labels, ``(b, t)`` after the optional
            reshape. A flattened ``(b * t,)`` tensor is accepted when
            ``batch_size`` is given.
        batch_size (Optional[int]): Batch size if target and output have been
            flattened.
        blank_label (Optional[int]): The blank character to be ignored.
            Defaults to 79.

    Returns:
        float: The mean per-sample word edit distance for the batch, or 0.0
        for an empty batch.

    """
    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)

    # Every target in the batch is treated as having the full padded length.
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths, blank_label=blank_label,
    )

    if len(decoded_predictions) == 0:
        return 0.0

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        prediction_words = "".join(prediction).split()
        target_words = "".join(target).split()

        # Injective word -> character mapping; the concrete characters do not
        # matter, only that equal words map to equal characters.
        vocabulary = {
            word: chr(index)
            for index, word in enumerate(set(prediction_words + target_words))
        }
        encoded_prediction = "".join(vocabulary[w] for w in prediction_words)
        encoded_target = "".join(vocabulary[w] for w in target_words)

        lev_dist += Lev.distance(encoded_prediction, encoded_target)

    return lev_dist / len(decoded_predictions)