1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
"""LSTM with CTC for handwritten text recognition within a line."""
from typing import Dict, Tuple
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from loguru import logger
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import configure_backbone
class ConvolutionalRecurrentNetwork(nn.Module):
    """Network that takes an image of a text line and predicts the tokens in it.

    The image is encoded by a CNN backbone (either patch-wise through a sliding
    window, or as a whole), the resulting sequence is fed to a recurrent
    network, and a linear layer followed by a log-softmax produces per-step
    class log-probabilities.
    """

    def __init__(
        self,
        backbone: str,
        backbone_args: Dict = None,
        input_size: int = 128,
        hidden_size: int = 128,
        bidirectional: bool = False,
        num_layers: int = 1,
        num_classes: int = 80,
        patch_size: Tuple[int, int] = (28, 28),
        stride: Tuple[int, int] = (1, 14),
        recurrent_cell: str = "lstm",
        avg_pool: bool = False,
        use_sliding_window: bool = True,
    ) -> None:
        """Builds the backbone, the recurrent network and the decoder.

        Args:
            backbone: Name of the backbone, forwarded to `configure_backbone`.
            backbone_args: Arguments for the backbone, forwarded to
                `configure_backbone`. `None` is treated as an empty dict.
            input_size: Feature size of each step fed to the recurrent network.
            hidden_size: Hidden size of the recurrent network.
            bidirectional: If True, the recurrent network is bidirectional and
                the decoder input size is doubled accordingly.
            num_layers: Number of stacked recurrent layers.
            num_classes: Number of output classes of the decoder.
            patch_size: (height, width) of the sliding window patches.
            stride: Stride of the sliding window.
            recurrent_cell: "lstm" or "gru" (case-insensitive). Any other value
                logs a warning and falls back to LSTM.
            avg_pool: If True, the backbone output of each patch is averaged
                over its spatial dimensions (sliding window mode only).
            use_sliding_window: If True, the image is split into patches before
                the backbone; otherwise the whole image is encoded at once.
        """
        super().__init__()
        self.backbone_args = backbone_args or {}
        self.patch_size = patch_size
        self.stride = stride
        self.sliding_window = (
            self._configure_sliding_window() if use_sliding_window else None
        )
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Pass the normalised dict so the backbone factory never receives None.
        # NOTE(review): assumes `configure_backbone` treats {} like "no args".
        self.backbone = configure_backbone(backbone, self.backbone_args)
        self.bidirectional = bidirectional
        self.avg_pool = avg_pool

        rnn_cls = self._resolve_recurrent_cell(recurrent_cell)
        self.rnn = rnn_cls(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            bidirectional=bidirectional,
            num_layers=num_layers,
        )

        # A bidirectional RNN concatenates both directions on the feature axis.
        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size

        self.decoder = nn.Sequential(
            nn.Linear(in_features=decoder_size, out_features=num_classes),
            nn.LogSoftmax(dim=2),
        )

    @staticmethod
    def _resolve_recurrent_cell(recurrent_cell: str) -> type:
        """Maps a case-insensitive cell name to `nn.LSTM` or `nn.GRU`.

        Unknown names log a warning and fall back to `nn.LSTM`.
        """
        cell_name = recurrent_cell.upper()
        if cell_name in ("LSTM", "GRU"):
            # Look up the upper-cased name: `torch.nn` only exposes `LSTM` and
            # `GRU`, so a lower-case name such as the default "lstm" would fail.
            return getattr(nn, cell_name)
        logger.warning(
            f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
        )
        return nn.LSTM

    def _configure_sliding_window(self) -> nn.Sequential:
        """Returns a module that unfolds an image into a sequence of patches.

        The output is arranged as (batch, patches, channels, height, width).
        The rearrange pattern fixes the channel count to 1.
        """
        return nn.Sequential(
            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
            Rearrange(
                "b (c h w) t -> b t c h w",
                h=self.patch_size[0],
                w=self.patch_size[1],
                c=1,
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.

        Args:
            x: Input image tensor. Inputs with fewer than 4 dimensions get
                leading singleton dimensions prepended until they have 4.

        Returns:
            Log-probabilities over the classes for each step, with the sequence
            dimension first and the class dimension last.
        """
        missing_dims = 4 - x.dim()
        if missing_dims > 0:
            # Indexing with a tuple of `None` prepends singleton dimensions.
            x = x[(None,) * missing_dims]

        if self.sliding_window is not None:
            # Create image patches with a sliding window kernel.
            x = self.sliding_window(x)

            # Fold the patch axis into the batch axis so the backbone sees
            # each patch as an independent image.
            b, t = x.shape[:2]
            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
            x = self.backbone(x)

            if self.avg_pool:
                # Average pooling over the spatial dimensions.
                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
            else:
                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
        else:
            # Encode the entire image with a CNN, and use the channels as temporal dimension.
            b = x.shape[0]
            x = self.backbone(x)
            x = rearrange(x, "b c h w -> c b (h w)", b=b)

        # Sequence predictions.
        x, _ = self.rnn(x)

        # Sequence to classification layer.
        x = self.decoder(x)
        return x
|