summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/ctc.py
blob: fc0d21d1bb851f9a327ff83edb63556639e742cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Decodes the CTC output."""
from typing import Callable, List, Optional, Tuple

from einops import rearrange
import torch
from torch import Tensor

from text_recognizer.datasets import EmnistMapper


def greedy_decoder(
    predictions: Tensor,
    targets: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    character_mapper: Optional[Callable] = None,
    blank_label: int = 79,
    collapse_repeated: bool = True,
) -> Tuple[List[List[str]], List[List[str]]]:
    """Greedy (best-path) CTC decoder.

    For every sample the most likely class is picked at each time step, blank
    labels are dropped and, optionally, runs of the same label are collapsed
    into one. The surviving label indices are passed through
    ``character_mapper`` one at a time.

    Args:
        predictions (Tensor): Tensor of network predictions, shape [time, batch, classes].
        targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
        target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
        character_mapper (Optional[Callable]): An emnist/character mapper for mapping integers to characters.
            Defaults to None, in which case a new EmnistMapper is created.
        blank_label (int): The blank character to be ignored. Defaults to 79.
        collapse_repeated (bool): Collapse consecutive predictions of the same character. Defaults to True.

    Returns:
        Tuple[List[List[str]], List[List[str]]]: Tuple of decoded predictions and decoded targets, each holding
            one list of mapped characters per sample. The decoded targets list is empty unless both `targets`
            and `target_lengths` are given.

    """
    if character_mapper is None:
        character_mapper = EmnistMapper()

    # Best class per time step, then [time, batch] -> [batch, time]. Converting
    # to nested Python ints once avoids a tensor op (and, on GPU, a device
    # sync) for every comparison in the loops below.
    best_paths = torch.argmax(predictions, dim=2).transpose(0, 1).tolist()
    decode_targets = targets is not None and target_lengths is not None

    decoded_predictions = []
    decoded_targets = []
    for i, best_path in enumerate(best_paths):
        if decode_targets:
            # Only the first `target_lengths[i]` entries are real labels; the
            # rest of the row is treated as padding and ignored.
            decoded_targets.append(
                [
                    character_mapper(int(target_index))
                    for target_index in targets[i][: target_lengths[i]]
                    if target_index != blank_label
                ]
            )

        kept_indices = []
        previous_index = None  # No predecessor at the first time step.
        for index in best_path:
            # A repeat is judged against the raw previous time step (blank
            # included), so "a, blank, a" still decodes to two characters.
            is_repeat = collapse_repeated and index == previous_index
            if index != blank_label and not is_repeat:
                kept_indices.append(index)
            previous_index = index
        decoded_predictions.append([character_mapper(index) for index in kept_indices])
    return decoded_predictions, decoded_targets