diff options
Diffstat (limited to 'src/text_recognizer/models/metrics.py')
-rw-r--r-- | src/text_recognizer/models/metrics.py | 21 |
1 file changed, 18 insertions, 3 deletions
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index 42c3c6e..af9adb5 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -6,7 +6,23 @@ from torch import Tensor from text_recognizer.networks import greedy_decoder -def accuracy(outputs: Tensor, labels: Tensor) -> float: +def accuracy_ignore_pad( + output: Tensor, + target: Tensor, + pad_index: int = 79, + eos_index: int = 81, + seq_len: int = 97, +) -> float: + """Sets all predictions after eos to pad.""" + start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) + end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) + for start, stop in zip(start_indices, end_indices): + output[start + 1 : stop] = pad_index + + return accuracy(output, target) + + +def accuracy(outputs: Tensor, labels: Tensor,) -> float: """Computes the accuracy. Args: @@ -17,10 +33,9 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float: float: The accuracy for the batch. """ - # eos_index = torch.nonzero(labels == eos, as_tuple=False) - # eos_index = eos_index[0].item() if eos_index.nelement() else -1 _, predicted = torch.max(outputs, dim=-1) + acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc |