author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/networks/metrics.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Initial commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/networks/metrics.py')
-rw-r--r-- | text_recognizer/networks/metrics.py | 123 |
1 file changed, 123 insertions, 0 deletions
diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py
new file mode 100644
index 0000000..2605731
--- /dev/null
+++ b/text_recognizer/networks/metrics.py
@@ -0,0 +1,123 @@
+"""Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
+import Levenshtein as Lev
+import torch
+from torch import Tensor
+
+from text_recognizer.networks import greedy_decoder
+
+
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
+    """Computes the accuracy.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        labels (Tensor): Ground truth labels.
+        pad_index (int): Padding index.
+
+    Returns:
+        float: The accuracy for the batch.
+
+    """
+    _, predicted = torch.max(outputs, dim=-1)
+
+    # Mask out the pad tokens so padding does not count towards the accuracy.
+    mask = labels != pad_index
+
+    acc = ((predicted == labels) & mask).sum().float() / mask.sum().float()
+    return acc.item()
+
+
+def cer(
+    outputs: Tensor,
+    targets: Tensor,
+    batch_size: Optional[int] = None,
+    blank_label: int = 79,
+) -> float:
+    """Computes the character error rate.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        targets (Tensor): Ground truth labels.
+        batch_size (Optional[int]): Batch size if targets and outputs have been flattened.
+        blank_label (int): The blank character to be ignored. Defaults to 79.
+
+    Returns:
+        float: The CER for the batch.
+
+    """
+    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
+    target_lengths = torch.full(
+        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+    )
+    decoded_predictions, decoded_targets = greedy_decoder(
+        outputs, targets, target_lengths, blank_label=blank_label,
+    )
+
+    lev_dist = 0
+
+    for prediction, target in zip(decoded_predictions, decoded_targets):
+        prediction = "".join(prediction)
+        target = "".join(target)
+        # Strip spaces so only character-level edits are counted.
+        prediction, target = (
+            prediction.replace(" ", ""),
+            target.replace(" ", ""),
+        )
+        lev_dist += Lev.distance(prediction, target)
+    return lev_dist / len(decoded_predictions)
+
+
+def wer(
+    outputs: Tensor,
+    targets: Tensor,
+    batch_size: Optional[int] = None,
+    blank_label: int = 79,
+) -> float:
+    """Computes the word error rate.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        targets (Tensor): Ground truth labels.
+        batch_size (Optional[int]): Batch size if targets and outputs have been flattened.
+        blank_label (int): The blank character to be ignored. Defaults to 79.
+
+    Returns:
+        float: The WER for the batch.
+ + """ + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths, blank_label=blank_label, + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + + b = set(prediction.split() + target.split()) + word2char = dict(zip(b, range(len(b)))) + + w1 = [chr(word2char[w]) for w in prediction.split()] + w2 = [chr(word2char[w]) for w in target.split()] + + lev_dist += Lev.distance("".join(w1), "".join(w2)) + + return lev_dist / len(decoded_predictions) |