summaryrefslogtreecommitdiff
path: root/text_recognizer/model/greedy_decoder.py
blob: 5cbbb66b4c99de800f62d650c33a5b4e257f2da3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Greedy decoder."""
from typing import Type
from text_recognizer.data.tokenizer import Tokenizer
import torch
from torch import nn, Tensor


class GreedyDecoder:
    def __init__(
        self,
        network: Type[nn.Module],
        tokenizer: Tokenizer,
        max_output_len: int = 682,
    ) -> None:
        self.network = network
        self.start_index = tokenizer.start_index
        self.end_index = tokenizer.end_index
        self.pad_index = tokenizer.pad_index
        self.max_output_len = max_output_len

    def __call__(self, x: Tensor) -> Tensor:
        bsz = x.shape[0]

        # Encode image(s) to latent vectors.
        img_features = self.network.encode(x)

        # Create a placeholder matrix for storing outputs from the network
        indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
        indecies[:, 0] = self.start_index

        try:
            for i in range(1, self.max_output_len):
                tokens = indecies[:, :i]  # (B, Sy)
                logits = self.network.decode(tokens, img_features)  # (B, C, Sy)
                indecies_ = torch.argmax(logits, dim=1)  # (B, Sy)
                indecies[:, i : i + 1] = indecies_[:, -1:]

                # Early stopping of prediction loop if token is end or padding token.
                if (
                    (indecies[:, i - 1] == self.end_index)
                    | (indecies[:, i - 1] == self.pad_index)
                ).all():
                    break

            # Set all tokens after end token to pad token.
            for i in range(1, self.max_output_len):
                idx = (indecies[:, i - 1] == self.end_index) | (
                    indecies[:, i - 1] == self.pad_index
                )
                indecies[idx, i] = self.pad_index

            return indecies
        except Exception:
            # TODO: investigate this error more
            print(x.shape)
            # print(indecies)
            print(indecies.shape)
            print(img_features.shape)