summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/ctc.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/ctc.py')
-rw-r--r--src/text_recognizer/networks/ctc.py66
1 files changed, 57 insertions, 9 deletions
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 00ad47e..fc0d21d 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -1,10 +1,58 @@
"""Decodes the CTC output."""
-#
-# from typing import Tuple
-# import torch
-#
-#
-# def greedy_decoder(
-# output, labels, label_length, blank_label, collapse_repeated=True
-# ) -> Tuple[torch.Tensor, torch.Tensor]:
-# pass
+from typing import Callable, List, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import Tensor
+
+from text_recognizer.datasets import EmnistMapper
+
+
+def greedy_decoder(
+ predictions: Tensor,
+ targets: Optional[Tensor] = None,
+ target_lengths: Optional[Tensor] = None,
+ character_mapper: Optional[Callable] = None,
+ blank_label: int = 79,
+ collapse_repeated: bool = True,
+) -> Tuple[List[str], List[str]]:
+ """Greedy CTC decoder.
+
+ Args:
+ predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
+ targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
+ target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
+ character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
+ to None.
+ blank_label (int): The blank character to be ignored. Defaults to 79.
+ collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
+
+ Returns:
+ Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
+
+ """
+
+ if character_mapper is None:
+ character_mapper = EmnistMapper()
+
+ predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
+ decoded_predictions = []
+ decoded_targets = []
+ for i, prediction in enumerate(predictions):
+ decoded_prediction = []
+ decoded_target = []
+ if targets is not None and target_lengths is not None:
+ for target_index in targets[i][: target_lengths[i]]:
+ if target_index == blank_label:
+ continue
+ decoded_target.append(character_mapper(int(target_index)))
+ decoded_targets.append(decoded_target)
+ for j, index in enumerate(prediction):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == prediction[j - 1]:
+ continue
+ decoded_prediction.append(index.item())
+ decoded_predictions.append(
+ [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
+ )
+ return decoded_predictions, decoded_targets