summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/metrics.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-12 23:42:03 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-12 23:42:03 +0100
commit8fdb6435e15703fa5b76df19728d905650ee1aef (patch)
treebe3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/text_recognizer/models/metrics.py
parentdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff)
parent6cb08a110620ee09fe9d8a5d008197a801d025df (diff)
Working cnn transformer.
Diffstat (limited to 'src/text_recognizer/models/metrics.py')
-rw-r--r--src/text_recognizer/models/metrics.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 42c3c6e..af9adb5 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -6,7 +6,23 @@ from torch import Tensor
from text_recognizer.networks import greedy_decoder
-def accuracy(outputs: Tensor, labels: Tensor) -> float:
+def accuracy_ignore_pad(
+ output: Tensor,
+ target: Tensor,
+ pad_index: int = 79,
+ eos_index: int = 81,
+ seq_len: int = 97,
+) -> float:
+ """Sets all predictions after eos to pad."""
+ start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
+ end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
+ for start, stop in zip(start_indices, end_indices):
+ output[start + 1 : stop] = pad_index
+
+ return accuracy(output, target)
+
+
+def accuracy(outputs: Tensor, labels: Tensor,) -> float:
"""Computes the accuracy.
Args:
@@ -17,10 +33,9 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
float: The accuracy for the batch.
"""
- # eos_index = torch.nonzero(labels == eos, as_tuple=False)
- # eos_index = eos_index[0].item() if eos_index.nelement() else -1
_, predicted = torch.max(outputs, dim=-1)
+
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc