blob: e2a30a91f46b3974f81eaf8a456dc34cab2886b8 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
"""Utility functions for models."""
import torch
def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Compute the classification accuracy for a batch.

    The predicted class for each sample is the index of the largest value
    along dimension 1 of ``outputs``. Ties resolve to the first maximal index.

    Args:
        outputs (torch.Tensor): The output from the network. Presumably of
            shape (batch, num_classes) holding logits or probabilities;
            confirm against callers.
        labels (torch.Tensor): Ground truth class indices, one per sample
            along dimension 0.

    Returns:
        float: Fraction of samples whose predicted class equals the label.

    Raises:
        ZeroDivisionError: If ``labels`` is an empty batch.
    """
    # argmax yields an integer tensor that is not tracked by autograd, so
    # the deprecated `.data` access is unnecessary.
    predicted = outputs.argmax(dim=1)
    correct = (predicted == labels).sum().item()
    return correct / labels.shape[0]
|