diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-06-23 22:39:54 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-06-23 22:39:54 +0200 |
commit | 7c4de6d88664d2ea1b084f316a11896dde3e1150 (patch) | |
tree | cbde7e64c8064e9b523dfb65cd6c487d061ec805 /src/text_recognizer/models/util.py | |
parent | a7a9ce59fc37317dd74c3a52caf7c4e68e570ee8 (diff) |
latest
Diffstat (limited to 'src/text_recognizer/models/util.py')
-rw-r--r-- | src/text_recognizer/models/util.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/util.py new file mode 100644 index 0000000..905fe7b --- /dev/null +++ b/src/text_recognizer/models/util.py @@ -0,0 +1,19 @@ +"""Utility functions for models.""" + +import torch + + +def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: + """Short summary. + + Args: + outputs (torch.Tensor): The output from the network. + labels (torch.Tensor): Ground truth labels. + + Returns: + float: The accuracy for the batch. + + """ + _, predicted = torch.max(outputs.data, dim=1) + acc = (predicted == labels).sum().item() / labels.shape[0] + return acc |