diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
commit | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch) | |
tree | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/networks/ctc.py | |
parent | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff) |
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/networks/ctc.py')
-rw-r--r-- | src/text_recognizer/networks/ctc.py | 66 |
1 files changed, 57 insertions, 9 deletions
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index 00ad47e..fc0d21d 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -1,10 +1,58 @@ """Decodes the CTC output.""" -# -# from typing import Tuple -# import torch -# -# -# def greedy_decoder( -# output, labels, label_length, blank_label, collapse_repeated=True -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# pass +from typing import Callable, List, Optional, Tuple + +from einops import rearrange +import torch +from torch import Tensor + +from text_recognizer.datasets import EmnistMapper + + +def greedy_decoder( + predictions: Tensor, + targets: Optional[Tensor] = None, + target_lengths: Optional[Tensor] = None, + character_mapper: Optional[Callable] = None, + blank_label: int = 79, + collapse_repeated: bool = True, +) -> Tuple[List[str], List[str]]: + """Greedy CTC decoder. + + Args: + predictions (Tensor): Tenor of network predictions, shape [time, batch, classes]. + targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None. + target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None. + character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults + to None. + blank_label (int): The blank character to be ignored. Defaults to 79. + collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True. + + Returns: + Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets. + + """ + + if character_mapper is None: + character_mapper = EmnistMapper() + + predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") + decoded_predictions = [] + decoded_targets = [] + for i, prediction in enumerate(predictions): + decoded_prediction = [] + decoded_target = [] + if targets is not None and target_lengths is not None: + for target_index in targets[i][: target_lengths[i]]: + if target_index == blank_label: + continue + decoded_target.append(character_mapper(int(target_index))) + decoded_targets.append(decoded_target) + for j, index in enumerate(prediction): + if index != blank_label: + if collapse_repeated and j != 0 and index == prediction[j - 1]: + continue + decoded_prediction.append(index.item()) + decoded_predictions.append( + [character_mapper(int(pred_index)) for pred_index in decoded_prediction] + ) + return decoded_predictions, decoded_targets |