summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/metrics.py')
-rw-r--r--src/text_recognizer/networks/metrics.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index ffad792..2605731 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -1,4 +1,7 @@
"""Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
import Levenshtein as Lev
import torch
from torch import Tensor
@@ -32,22 +35,33 @@ def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
return acc
-def cer(outputs: Tensor, targets: Tensor) -> float:
+def cer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the character error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+        batch_size (Optional[int]): Batch size if targets and outputs have been flattened.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The cer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0
@@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float:
return lev_dist / len(decoded_predictions)
-def wer(outputs: Tensor, targets: Tensor) -> float:
+def wer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the Word error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+        batch_size (Optional[int]): Batch size if targets and outputs have been flattened.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The wer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0