diff options
Diffstat (limited to 'src/text_recognizer/networks/metrics.py')
-rw-r--r-- | src/text_recognizer/networks/metrics.py | 33 |
1 files changed, 29 insertions, 4 deletions
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py index ffad792..2605731 100644 --- a/src/text_recognizer/networks/metrics.py +++ b/src/text_recognizer/networks/metrics.py @@ -1,4 +1,7 @@ """Utility functions for models.""" +from typing import Optional + +from einops import rearrange import Levenshtein as Lev import torch from torch import Tensor @@ -32,22 +35,33 @@ def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: return acc -def cer(outputs: Tensor, targets: Tensor) -> float: +def cer( + outputs: Tensor, + targets: Tensor, + batch_size: Optional[int] = None, + blank_label: Optional[int] = int, +) -> float: """Computes the character error rate. Args: outputs (Tensor): The output from the network. targets (Tensor): Ground truth labels. + batch_size (Optional[int]): Batch size if target and output has been flattend. + blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. Returns: float: The cer for the batch. """ + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + target_lengths = torch.full( size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, ) decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths + outputs, targets, target_lengths, blank_label=blank_label, ) lev_dist = 0 @@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float: return lev_dist / len(decoded_predictions) -def wer(outputs: Tensor, targets: Tensor) -> float: +def wer( + outputs: Tensor, + targets: Tensor, + batch_size: Optional[int] = None, + blank_label: Optional[int] = int, +) -> float: """Computes the Word error rate. Args: outputs (Tensor): The output from the network. targets (Tensor): Ground truth labels. + batch_size (optional[int]): Batch size if target and output has been flattend. + blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. Returns: float: The wer for the batch. """ + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + target_lengths = torch.full( size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, ) decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths + outputs, targets, target_lengths, blank_label=blank_label, ) lev_dist = 0 |