diff options
Diffstat (limited to 'text_recognizer/networks/crnn.py')
-rw-r--r-- | text_recognizer/networks/crnn.py | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py new file mode 100644 index 0000000..778e232 --- /dev/null +++ b/text_recognizer/networks/crnn.py @@ -0,0 +1,110 @@ +"""CRNN for handwritten text recognition.""" +from typing import Dict, Tuple + +from einops import rearrange, reduce +from einops.layers.torch import Rearrange +from loguru import logger +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import configure_backbone + + +class ConvolutionalRecurrentNetwork(nn.Module): + """Network that takes a image of a text line and predicts tokens that are in the image.""" + + def __init__( + self, + backbone: str, + backbone_args: Dict = None, + input_size: int = 128, + hidden_size: int = 128, + bidirectional: bool = False, + num_layers: int = 1, + num_classes: int = 80, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + recurrent_cell: str = "lstm", + avg_pool: bool = False, + use_sliding_window: bool = True, + ) -> None: + super().__init__() + self.backbone_args = backbone_args or {} + self.patch_size = patch_size + self.stride = stride + self.sliding_window = ( + self._configure_sliding_window() if use_sliding_window else None + ) + self.input_size = input_size + self.hidden_size = hidden_size + self.backbone = configure_backbone(backbone, backbone_args) + self.bidirectional = bidirectional + self.avg_pool = avg_pool + + if recurrent_cell.upper() in ["LSTM", "GRU"]: + recurrent_cell = getattr(nn, recurrent_cell) + else: + logger.warning( + f"Option {recurrent_cell} not valid, defaulting to LSTM cell." + ) + recurrent_cell = nn.LSTM + + self.rnn = recurrent_cell( + input_size=self.input_size, + hidden_size=self.hidden_size, + bidirectional=bidirectional, + num_layers=num_layers, + ) + + decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size + + self.decoder = nn.Sequential( + nn.Linear(in_features=decoder_size, out_features=num_classes), + nn.LogSoftmax(dim=2), + ) + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def forward(self, x: Tensor) -> Tensor: + """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + + if self.sliding_window is not None: + # Create image patches with a sliding window kernel. + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + + x = self.backbone(x) + + # Average pooling. + if self.avg_pool: + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + else: + x = rearrange(x, "(b t) h -> t b h", b=b, t=t) + else: + # Encode the entire image with a CNN, and use the channels as temporal dimension. + x = self.backbone(x) + x = rearrange(x, "b c h w -> b w c h") + if self.adaptive_pool is not None: + x = self.adaptive_pool(x) + x = x.squeeze(3) + + # Sequence predictions. + x, _ = self.rnn(x) + + # Sequence to classification layer. + x = self.decoder(x) + return x |