summary refs log tree commit diff
path: root/src/text_recognizer/networks/crnn.py
blob: 778e232af83bd4fc72ce013ea5f24c521e22caa1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""CRNN for handwritten text recognition."""
from typing import Dict, Optional, Tuple

from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from loguru import logger
from torch import nn
from torch import Tensor

from text_recognizer.networks.util import configure_backbone


class ConvolutionalRecurrentNetwork(nn.Module):
    """Network that takes an image of a text line and predicts the tokens in the image.

    The image is encoded by a CNN backbone, either patch by patch (sliding
    window) or as a whole, and the resulting feature sequence is fed to a
    recurrent network followed by a linear + log-softmax classification layer.
    """

    def __init__(
        self,
        backbone: str,
        backbone_args: Optional[Dict] = None,
        input_size: int = 128,
        hidden_size: int = 128,
        bidirectional: bool = False,
        num_layers: int = 1,
        num_classes: int = 80,
        patch_size: Tuple[int, int] = (28, 28),
        stride: Tuple[int, int] = (1, 14),
        recurrent_cell: str = "lstm",
        avg_pool: bool = False,
        use_sliding_window: bool = True,
    ) -> None:
        """Builds the backbone, the recurrent network and the decoder.

        Args:
            backbone: Name of the backbone, forwarded to `configure_backbone`.
            backbone_args: Arguments for the backbone, forwarded to
                `configure_backbone`.
            input_size: Feature size of each element fed to the RNN.
            hidden_size: Hidden size of the RNN.
            bidirectional: If True, a bidirectional RNN is used.
            num_layers: Number of stacked recurrent layers.
            num_classes: Number of output classes of the decoder.
            patch_size: (height, width) of the sliding window patches.
            stride: Stride of the sliding window.
            recurrent_cell: "lstm" or "gru" (case-insensitive). Any other value
                logs a warning and falls back to LSTM.
            avg_pool: If True, the backbone output is average pooled over its
                spatial dimensions (sliding window) or over its height (whole
                image) before the RNN.
            use_sliding_window: If True, the image is cut into patches that are
                encoded one by one; otherwise the whole image is encoded at once.
        """
        super().__init__()
        self.backbone_args = backbone_args or {}
        self.patch_size = patch_size
        self.stride = stride
        self.sliding_window = (
            self._configure_sliding_window() if use_sliding_window else None
        )
        self.input_size = input_size
        self.hidden_size = hidden_size
        # NOTE(review): the raw `backbone_args` (possibly None) is forwarded,
        # not `self.backbone_args` -- confirm configure_backbone accepts None.
        self.backbone = configure_backbone(backbone, backbone_args)
        self.bidirectional = bidirectional
        self.avg_pool = avg_pool

        # `forward` reads this attribute on the whole-image path, so it must
        # always exist. The pool collapses the last (height) axis of the
        # "b w c h" feature map to 1 so that it can be squeezed away.
        self.adaptive_pool = (
            nn.AdaptiveAvgPool2d((None, 1))
            if avg_pool and not use_sliding_window
            else None
        )

        rnn_class = self._resolve_recurrent_cell(recurrent_cell)
        self.rnn = rnn_class(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            bidirectional=bidirectional,
            num_layers=num_layers,
        )

        # A bidirectional RNN concatenates the outputs of both directions.
        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size

        self.decoder = nn.Sequential(
            nn.Linear(in_features=decoder_size, out_features=num_classes),
            nn.LogSoftmax(dim=2),
        )

    @staticmethod
    def _resolve_recurrent_cell(name: str) -> type:
        """Maps a case-insensitive cell name to nn.LSTM / nn.GRU, defaulting to nn.LSTM."""
        cell_name = name.upper()
        if cell_name in ("LSTM", "GRU"):
            # The torch.nn classes are upper case, so look up the normalized name.
            return getattr(nn, cell_name)
        logger.warning(f"Option {name} not valid, defaulting to LSTM cell.")
        return nn.LSTM

    def _configure_sliding_window(self) -> nn.Sequential:
        """Builds the module that cuts an image batch into a sequence of patches."""
        return nn.Sequential(
            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
            # Unfold flattens each patch channel-major into (c * h * w); the
            # channel count is inferred from the given patch height and width.
            Rearrange(
                "b (c h w) t -> b t c h w",
                h=self.patch_size[0],
                w=self.patch_size[1],
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Converts images to a feature sequence with a CNN, then predicts tokens with an RNN.

        Args:
            x: Image tensor with at most four dimensions; missing leading
                dimensions are added.

        Returns:
            Log-probabilities over the classes, laid out as
            (time, batch, num_classes).
        """
        if len(x.shape) < 4:
            # Prepend singleton dimensions until the tensor is 4-D.
            x = x[(None,) * (4 - len(x.shape))]

        if self.sliding_window is not None:
            # Create image patches with a sliding window kernel.
            x = self.sliding_window(x)

            # Fold the patch sequence into the batch for the feedforward network.
            b, t = x.shape[:2]
            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)

            x = self.backbone(x)

            if self.avg_pool:
                # Average over the spatial dimensions of each patch encoding.
                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
            else:
                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
        else:
            # Encode the entire image with a CNN, and use the width of the
            # feature map as the temporal dimension.
            x = self.backbone(x)
            x = rearrange(x, "b c h w -> b w c h")
            if self.adaptive_pool is not None:
                x = self.adaptive_pool(x)
            x = x.squeeze(3)
            # The RNN is time-major (batch_first is left at its default), the
            # same layout the sliding window path produces.
            x = rearrange(x, "b t h -> t b h")

        # Sequence predictions.
        x, _ = self.rnn(x)

        # Sequence to classification layer.
        return self.decoder(x)