summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/crnn.py
blob: 9747429bb639e47f45c984c227c85cfe27ef5c20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""LSTM with CTC for handwritten text recognition within a line."""
from typing import Dict, Tuple

from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from loguru import logger
from torch import nn
from torch import Tensor

from text_recognizer.networks.util import configure_backbone


class ConvolutionalRecurrentNetwork(nn.Module):
    """Network that takes an image of a text line and predicts the tokens in the image.

    The image is encoded by a CNN backbone, either patch by patch (sliding
    window) or as a whole, the resulting sequence is fed to a recurrent
    network, and a linear layer followed by a log-softmax produces per-step
    class log-probabilities.
    """

    def __init__(
        self,
        backbone: str,
        backbone_args: Dict = None,
        input_size: int = 128,
        hidden_size: int = 128,
        bidirectional: bool = False,
        num_layers: int = 1,
        num_classes: int = 80,
        patch_size: Tuple[int, int] = (28, 28),
        stride: Tuple[int, int] = (1, 14),
        recurrent_cell: str = "lstm",
        avg_pool: bool = False,
        use_sliding_window: bool = True,
    ) -> None:
        """Builds the backbone, the recurrent network and the decoder.

        Args:
            backbone: Name of the backbone, forwarded to `configure_backbone`.
            backbone_args: Arguments for the backbone; `None` is treated as an
                empty dict.
            input_size: Feature size of each step fed to the recurrent network.
            hidden_size: Hidden size of the recurrent network.
            bidirectional: If True, the recurrent network is bidirectional and
                the decoder input size is doubled accordingly.
            num_layers: Number of stacked recurrent layers.
            num_classes: Number of output classes of the decoder.
            patch_size: (height, width) of the sliding window patches.
            stride: Stride of the sliding window.
            recurrent_cell: "lstm" or "gru" (case-insensitive). Any other
                string logs a warning and falls back to an LSTM.
            avg_pool: If True, the backbone output of each patch is averaged
                over its spatial dimensions in the sliding window path.
            use_sliding_window: If True, the image is split into patches
                before it is passed to the backbone.
        """
        super().__init__()
        self.backbone_args = backbone_args or {}
        self.patch_size = patch_size
        self.stride = stride
        self.sliding_window = (
            self._configure_sliding_window() if use_sliding_window else None
        )
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Pass the normalized dict so the backbone factory never receives None.
        # NOTE(review): assumes `configure_backbone` treats an empty dict as
        # "no extra arguments" - confirm against its implementation.
        self.backbone = configure_backbone(backbone, self.backbone_args)
        self.bidirectional = bidirectional
        self.avg_pool = avg_pool

        rnn_class = self._select_recurrent_cell(recurrent_cell)
        self.rnn = rnn_class(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            bidirectional=bidirectional,
            num_layers=num_layers,
        )

        # A bidirectional RNN concatenates both directions along the feature axis.
        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size

        self.decoder = nn.Sequential(
            nn.Linear(in_features=decoder_size, out_features=num_classes),
            nn.LogSoftmax(dim=2),
        )

    @staticmethod
    def _select_recurrent_cell(recurrent_cell: str) -> type:
        """Returns the `torch.nn` recurrent layer class named by `recurrent_cell`.

        The lookup is case-insensitive: the attribute is resolved from the
        upper-cased name, since `torch.nn` only exposes `LSTM` and `GRU`.
        Unknown names log a warning and fall back to `nn.LSTM`.
        """
        cell_name = recurrent_cell.upper()
        if cell_name in ("LSTM", "GRU"):
            return getattr(nn, cell_name)

        logger.warning(
            f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
        )
        return nn.LSTM

    def _configure_sliding_window(self) -> nn.Sequential:
        """Builds the module that cuts an image into a sequence of patches.

        `nn.Unfold` extracts flattened patches of `self.patch_size` with
        `self.stride`; the rearrange then restores each patch to an image with
        a single channel, giving a tensor of shape (b, t, 1, h, w).
        """
        return nn.Sequential(
            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
            Rearrange(
                "b (c h w) t -> b t c h w",
                h=self.patch_size[0],
                w=self.patch_size[1],
                c=1,
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.

        Args:
            x: Image tensor. Inputs with fewer than four dimensions get leading
                singleton dimensions prepended until they have four.

        Returns:
            Log-probabilities over the classes with the sequence dimension
            first, i.e. a tensor of shape (t, b, num_classes).
        """
        if len(x.shape) < 4:
            # Indexing with None prepends the missing leading dimensions.
            x = x[(None,) * (4 - len(x.shape))]

        if self.sliding_window is not None:
            # Create image patches with a sliding window kernel.
            x = self.sliding_window(x)

            # Fold the sequence of patches into the batch dimension for the backbone.
            b, t = x.shape[:2]
            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)

            x = self.backbone(x)

            # Unfold the batch again, with the sequence dimension first for the RNN.
            if self.avg_pool:
                # Average pooling over the spatial dimensions of each patch.
                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
            else:
                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
        else:
            # Encode the entire image with a CNN, and use the channels as temporal dimension.
            b = x.shape[0]
            x = self.backbone(x)
            x = rearrange(x, "b c h w -> c b (h w)", b=b)

        # Sequence predictions.
        x, _ = self.rnn(x)

        # Sequence to classification layer.
        x = self.decoder(x)
        return x