summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/networks/metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/metrics.py')
-rw-r--r-- src/text_recognizer/networks/metrics.py 25
1 files changed, 8 insertions, 17 deletions
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index af9adb5..ffad792 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -6,28 +6,13 @@ from torch import Tensor
from text_recognizer.networks import greedy_decoder
-def accuracy_ignore_pad(
- output: Tensor,
- target: Tensor,
- pad_index: int = 79,
- eos_index: int = 81,
- seq_len: int = 97,
-) -> float:
- """Sets all predictions after eos to pad."""
- start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
- end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
- for start, stop in zip(start_indices, end_indices):
- output[start + 1 : stop] = pad_index
-
- return accuracy(output, target)
-
-
-def accuracy(outputs: Tensor, labels: Tensor,) -> float:
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
"""Computes the accuracy.
Args:
outputs (Tensor): The output from the network.
labels (Tensor): Ground truth labels.
+ pad_index (int): Padding index.
Returns:
float: The accuracy for the batch.
@@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float:
_, predicted = torch.max(outputs, dim=-1)
+ # Mask out the pad tokens
+ mask = labels != pad_index
+
+ predicted *= mask
+ labels *= mask
+
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc