diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
commit | ff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch) | |
tree | afee959135416fe92cf6df377e84fb0a9e9714a0 /src/text_recognizer/networks/metrics.py | |
parent | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff) |
Minor updates.
Diffstat (limited to 'src/text_recognizer/networks/metrics.py')
-rw-r--r-- | src/text_recognizer/networks/metrics.py | 25 |
1 files changed, 8 insertions, 17 deletions
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py index af9adb5..ffad792 100644 --- a/src/text_recognizer/networks/metrics.py +++ b/src/text_recognizer/networks/metrics.py @@ -6,28 +6,13 @@ from torch import Tensor from text_recognizer.networks import greedy_decoder -def accuracy_ignore_pad( - output: Tensor, - target: Tensor, - pad_index: int = 79, - eos_index: int = 81, - seq_len: int = 97, -) -> float: - """Sets all predictions after eos to pad.""" - start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) - end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) - for start, stop in zip(start_indices, end_indices): - output[start + 1 : stop] = pad_index - - return accuracy(output, target) - - -def accuracy(outputs: Tensor, labels: Tensor,) -> float: +def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: """Computes the accuracy. Args: outputs (Tensor): The output from the network. labels (Tensor): Ground truth labels. + pad_index (int): Padding index. Returns: float: The accuracy for the batch. @@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float: _, predicted = torch.max(outputs, dim=-1) + # Mask out the pad tokens + mask = labels != pad_index + + predicted *= mask + labels *= mask + acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc |