1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
"""Utility functions for models."""
from typing import Optional
from einops import rearrange
import Levenshtein as Lev
import torch
from torch import Tensor
from text_recognizer.networks import greedy_decoder
def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
    """Computes the accuracy.

    The predicted class is the argmax of ``outputs`` over its last dimension.
    Positions where ``labels`` equals ``pad_index`` are zeroed in both the
    predictions and the labels before they are compared, so padded positions
    always count as matches. The number of matches is divided by the size of
    the first dimension of ``labels``.

    Args:
        outputs (Tensor): The output from the network.
        labels (Tensor): Ground truth labels. The tensor is not modified.
        pad_index (int): Padding index.

    Returns:
        float: The accuracy for the batch.

    """
    _, predicted = torch.max(outputs, dim=-1)

    # Mask out the pad tokens.
    mask = labels != pad_index

    # `predicted` is a fresh tensor owned by this function, so it can be
    # masked in place. `labels` belongs to the caller and must not be
    # overwritten, hence the out-of-place multiplication.
    predicted *= mask
    masked_labels = labels * mask

    matches = (predicted == masked_labels).sum().float()
    return (matches / labels.shape[0]).item()
def cer(
    outputs: Tensor,
    targets: Tensor,
    batch_size: Optional[int] = None,
    blank_label: Optional[int] = 79,
) -> float:
    """Computes the character error rate.

    Predictions and targets are decoded with ``greedy_decoder``, spaces are
    removed from both strings, and the Levenshtein distances are summed over
    the batch and divided by the number of decoded predictions.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.
        batch_size (Optional[int]): Batch size if target and output has been
            flattened.
        blank_label (Optional[int]): The blank character to be ignored.
            Defaults to 79.

    Returns:
        float: The cer for the batch.

    """
    # Undo flattening: targets become (b, t) and outputs become (t, b, v).
    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)

    # Every target is given the same length: the second dimension of targets.
    # NOTE(review): using outputs.shape[1] as the batch size assumes
    # un-flattened outputs are already laid out as (t, b, v) - confirm.
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )

    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths, blank_label=blank_label,
    )

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        # Spaces are dropped so only non-space characters affect the distance.
        prediction = "".join(prediction).replace(" ", "")
        target = "".join(target).replace(" ", "")
        lev_dist += Lev.distance(prediction, target)

    return lev_dist / len(decoded_predictions)
def wer(
    outputs: Tensor,
    targets: Tensor,
    batch_size: Optional[int] = None,
    blank_label: Optional[int] = 79,
) -> float:
    """Computes the Word error rate.

    Predictions and targets are decoded with ``greedy_decoder`` and split on
    whitespace into words. The word-level Levenshtein distances are summed
    over the batch and divided by the number of decoded predictions.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.
        batch_size (Optional[int]): Batch size if target and output has been
            flattened.
        blank_label (Optional[int]): The blank character to be ignored.
            Defaults to 79.

    Returns:
        float: The wer for the batch.

    """
    # Undo flattening: targets become (b, t) and outputs become (t, b, v).
    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)

    # Every target is given the same length: the second dimension of targets.
    # NOTE(review): using outputs.shape[1] as the batch size assumes
    # un-flattened outputs are already laid out as (t, b, v) - confirm.
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )

    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths, blank_label=blank_label,
    )

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        predicted_words = "".join(prediction).split()
        target_words = "".join(target).split()

        # Give every distinct word its own single character, so that a
        # character-level Levenshtein distance over the encoded strings
        # counts word insertions, deletions and substitutions.
        vocabulary = set(predicted_words + target_words)
        word2char = {word: chr(index) for index, word in enumerate(vocabulary)}

        encoded_prediction = "".join(word2char[word] for word in predicted_words)
        encoded_target = "".join(word2char[word] for word in target_words)
        lev_dist += Lev.distance(encoded_prediction, encoded_target)

    return lev_dist / len(decoded_predictions)
|