diff options
Diffstat (limited to 'src/text_recognizer/networks/metrics.py')
-rw-r--r-- | src/text_recognizer/networks/metrics.py | 25 |
1 files changed, 8 insertions, 17 deletions
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py index af9adb5..ffad792 100644 --- a/src/text_recognizer/networks/metrics.py +++ b/src/text_recognizer/networks/metrics.py @@ -6,28 +6,13 @@ from torch import Tensor from text_recognizer.networks import greedy_decoder -def accuracy_ignore_pad( - output: Tensor, - target: Tensor, - pad_index: int = 79, - eos_index: int = 81, - seq_len: int = 97, -) -> float: - """Sets all predictions after eos to pad.""" - start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) - end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) - for start, stop in zip(start_indices, end_indices): - output[start + 1 : stop] = pad_index - - return accuracy(output, target) - - -def accuracy(outputs: Tensor, labels: Tensor,) -> float: +def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: """Computes the accuracy. Args: outputs (Tensor): The output from the network. labels (Tensor): Ground truth labels. + pad_index (int): Padding index. Returns: float: The accuracy for the batch. @@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float: _, predicted = torch.max(outputs, dim=-1) + # Mask out the pad tokens + mask = labels != pad_index + + predicted *= mask + labels *= mask + acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc |