1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
"""Utility functions for models."""
import Levenshtein as Lev
import torch
from torch import Tensor
from text_recognizer.networks import greedy_decoder
def accuracy(outputs: Tensor, labels: Tensor) -> float:
    """Computes the accuracy.

    The predicted class is the argmax over the last dimension of ``outputs``.
    The accuracy is the fraction of predictions equal to ``labels``, counted
    over every element of ``labels``. For 1-D labels this is the usual
    per-sample accuracy. For higher-rank labels (e.g. per-token labels) it is
    the per-element accuracy, so the result always lies in [0, 1].

    Args:
        outputs (Tensor): The output from the network, with class scores in
            the last dimension.
        labels (Tensor): Ground truth labels, shaped like ``outputs`` without
            its last dimension.

    Returns:
        float: The accuracy for the batch (nan if ``labels`` is empty).
    """
    _, predicted = torch.max(outputs, dim=-1)
    # Divide by the total number of labels rather than the batch size, so
    # that multi-dimensional labels still yield a fraction in [0, 1].
    acc = (predicted == labels).sum().float() / labels.numel()
    return acc.item()
def cer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the character error rate.

    Both the prediction and the target of each sample are decoded with
    ``greedy_decoder`` and have their spaces removed. They are then compared
    with the Levenshtein edit distance. The returned value is the mean
    (unnormalized) edit distance per sample. It is NOT divided by the target
    length.

    Args:
        outputs (Tensor): The output from the network. Its dimension 1 is used
            as the number of samples (presumably a (time, batch, classes)
            layout -- TODO confirm against greedy_decoder).
        targets (Tensor): Ground truth labels. Its dimension 1 is used as the
            target length for every sample.

    Returns:
        float: The mean character edit distance for the batch, or 0.0 if the
        decoder yields no samples.
    """
    # Every sample is given the full (padded) target length; presumably
    # greedy_decoder handles any padding itself -- verify.
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths
    )
    # An empty batch would otherwise raise ZeroDivisionError below.
    if len(decoded_predictions) == 0:
        return 0.0

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        # Spaces are ignored so that only character errors are counted.
        prediction_chars = "".join(prediction).replace(" ", "")
        target_chars = "".join(target).replace(" ", "")
        lev_dist += Lev.distance(prediction_chars, target_chars)
    return lev_dist / len(decoded_predictions)
def wer(outputs: Tensor, targets: Tensor) -> float:
    """Computes the word error rate.

    Both the prediction and the target of each sample are decoded with
    ``greedy_decoder`` and split on whitespace into words. Each distinct word
    is mapped to a single unique character, so that a character-level
    Levenshtein distance counts word insertions, deletions and substitutions.
    The returned value is the mean (unnormalized) word edit distance per
    sample. It is NOT divided by the number of target words.

    Args:
        outputs (Tensor): The output from the network. Its dimension 1 is used
            as the number of samples (presumably a (time, batch, classes)
            layout -- TODO confirm against greedy_decoder).
        targets (Tensor): Ground truth labels. Its dimension 1 is used as the
            target length for every sample.

    Returns:
        float: The mean word edit distance for the batch, or 0.0 if the
        decoder yields no samples.
    """
    # Every sample is given the full (padded) target length; presumably
    # greedy_decoder handles any padding itself -- verify.
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    decoded_predictions, decoded_targets = greedy_decoder(
        outputs, targets, target_lengths
    )
    # An empty batch would otherwise raise ZeroDivisionError below.
    if len(decoded_predictions) == 0:
        return 0.0

    lev_dist = 0
    for prediction, target in zip(decoded_predictions, decoded_targets):
        predicted_words = "".join(prediction).split()
        target_words = "".join(target).split()
        # Encode each distinct word as one character (a bijection), so the
        # string edit distance equals the word edit distance.
        # dict.fromkeys keeps a deterministic first-seen order.
        word_to_char = {
            word: chr(index)
            for index, word in enumerate(
                dict.fromkeys(predicted_words + target_words)
            )
        }
        encoded_prediction = "".join(word_to_char[w] for w in predicted_words)
        encoded_target = "".join(word_to_char[w] for w in target_words)
        lev_dist += Lev.distance(encoded_prediction, encoded_target)
    return lev_dist / len(decoded_predictions)
|