summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/ctc.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/networks/ctc.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/networks/ctc.py')
-rw-r--r--text_recognizer/networks/ctc.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
new file mode 100644
index 0000000..af9b700
--- /dev/null
+++ b/text_recognizer/networks/ctc.py
@@ -0,0 +1,58 @@
+"""Decodes the CTC output."""
+from typing import Callable, List, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import Tensor
+
+from text_recognizer.datasets.util import EmnistMapper
+
+
+def greedy_decoder(
+ predictions: Tensor,
+ targets: Optional[Tensor] = None,
+ target_lengths: Optional[Tensor] = None,
+ character_mapper: Optional[Callable] = None,
+ blank_label: int = 79,
+ collapse_repeated: bool = True,
+) -> Tuple[List[str], List[str]]:
+ """Greedy CTC decoder.
+
+ Args:
+ predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
+ targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
+ target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
+ character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
+ to None.
+ blank_label (int): The blank character to be ignored. Defaults to 80.
+ collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
+
+ Returns:
+ Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
+
+ """
+
+ if character_mapper is None:
+ character_mapper = EmnistMapper(pad_token="_") # noqa: S106
+
+ predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
+ decoded_predictions = []
+ decoded_targets = []
+ for i, prediction in enumerate(predictions):
+ decoded_prediction = []
+ decoded_target = []
+ if targets is not None and target_lengths is not None:
+ for target_index in targets[i][: target_lengths[i]]:
+ if target_index == blank_label:
+ continue
+ decoded_target.append(character_mapper(int(target_index)))
+ decoded_targets.append(decoded_target)
+ for j, index in enumerate(prediction):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == prediction[j - 1]:
+ continue
+ decoded_prediction.append(index.item())
+ decoded_predictions.append(
+ [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
+ )
+ return decoded_predictions, decoded_targets