summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/metrics.py
blob: ac8d68e29a4b60ff7ce4c468181fe0948386d50a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Utility functions for models."""

import torch


def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Computes the accuracy.

    The predicted class is the argmax of ``outputs`` along dim 1. The
    accuracy is the fraction of predicted classes equal to ``labels``,
    averaged over every labelled element. For example, with outputs of
    shape (N, C, L) and labels of shape (N, L), all N * L positions count.

    Args:
        outputs (torch.Tensor): The output from the network, with class
            scores along dim 1.
        labels (torch.Tensor): Ground truth labels. The shape must equal
            ``outputs`` with dim 1 removed.

    Returns:
        float: The accuracy for the batch. Returns 0.0 for an empty batch.

    Raises:
        ValueError: If the shape of ``labels`` does not match the shape of
            the predictions.

    """
    total = labels.numel()
    if total == 0:
        # No labelled elements: accuracy is undefined, so report 0.0
        # rather than dividing by zero.
        return 0.0

    # argmax yields an integer index tensor that autograd never tracks,
    # so there is no need for the deprecated `.data` attribute.
    predicted = outputs.argmax(dim=1)

    if predicted.shape != labels.shape:
        # Without this check, e.g. labels of shape (N, 1) would broadcast
        # against predictions of shape (N,) and give a meaningless count.
        raise ValueError(
            f"labels shape {tuple(labels.shape)} does not match "
            f"predictions shape {tuple(predicted.shape)}"
        )

    correct = (predicted == labels).sum().item()
    return correct / total