summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/ctc.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/networks/ctc.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/networks/ctc.py')
-rw-r--r--src/text_recognizer/networks/ctc.py58
1 files changed, 0 insertions, 58 deletions
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
deleted file mode 100644
index af9b700..0000000
--- a/src/text_recognizer/networks/ctc.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Decodes the CTC output."""
-from typing import Callable, List, Optional, Tuple
-
-from einops import rearrange
-import torch
-from torch import Tensor
-
-from text_recognizer.datasets.util import EmnistMapper
-
-
-def greedy_decoder(
- predictions: Tensor,
- targets: Optional[Tensor] = None,
- target_lengths: Optional[Tensor] = None,
- character_mapper: Optional[Callable] = None,
- blank_label: int = 79,
- collapse_repeated: bool = True,
-) -> Tuple[List[str], List[str]]:
- """Greedy CTC decoder.
-
- Args:
- predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
- targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
- target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
- character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
- to None.
- blank_label (int): The blank character to be ignored. Defaults to 80.
- collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
-
- Returns:
- Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
-
- """
-
- if character_mapper is None:
- character_mapper = EmnistMapper(pad_token="_") # noqa: S106
-
- predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
- decoded_predictions = []
- decoded_targets = []
- for i, prediction in enumerate(predictions):
- decoded_prediction = []
- decoded_target = []
- if targets is not None and target_lengths is not None:
- for target_index in targets[i][: target_lengths[i]]:
- if target_index == blank_label:
- continue
- decoded_target.append(character_mapper(int(target_index)))
- decoded_targets.append(decoded_target)
- for j, index in enumerate(prediction):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == prediction[j - 1]:
- continue
- decoded_prediction.append(index.item())
- decoded_predictions.append(
- [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
- )
- return decoded_predictions, decoded_targets