Diffstat (limited to 'text_recognizer/networks/crnn.py')
-rw-r--r--  text_recognizer/networks/crnn.py  110
1 file changed, 110 insertions, 0 deletions
diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py
new file mode 100644
index 0000000..778e232
--- /dev/null
+++ b/text_recognizer/networks/crnn.py
@@ -0,0 +1,110 @@
+"""CRNN for handwritten text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import configure_backbone
+
+
+class ConvolutionalRecurrentNetwork(nn.Module):
+    """Network that takes an image of a text line and predicts the tokens in the image."""
+
+ def __init__(
+ self,
+ backbone: str,
+        backbone_args: Optional[Dict] = None,
+ input_size: int = 128,
+ hidden_size: int = 128,
+ bidirectional: bool = False,
+ num_layers: int = 1,
+ num_classes: int = 80,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ recurrent_cell: str = "lstm",
+ avg_pool: bool = False,
+ use_sliding_window: bool = True,
+ ) -> None:
+ super().__init__()
+ self.backbone_args = backbone_args or {}
+ self.patch_size = patch_size
+ self.stride = stride
+ self.sliding_window = (
+ self._configure_sliding_window() if use_sliding_window else None
+ )
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+        self.backbone = configure_backbone(backbone, self.backbone_args)
+ self.bidirectional = bidirectional
+        self.avg_pool = avg_pool
+        # Assumed fix: forward() references self.adaptive_pool, which was never
+        # defined. Pool the height dimension to 1 when the sliding window is off.
+        self.adaptive_pool = (
+            nn.AdaptiveAvgPool2d((None, 1)) if not use_sliding_window else None
+        )
+
+ if recurrent_cell.upper() in ["LSTM", "GRU"]:
+            recurrent_cell = getattr(nn, recurrent_cell.upper())
+ else:
+ logger.warning(
+ f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
+ )
+ recurrent_cell = nn.LSTM
+
+ self.rnn = recurrent_cell(
+ input_size=self.input_size,
+ hidden_size=self.hidden_size,
+ bidirectional=bidirectional,
+ num_layers=num_layers,
+ )
+
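+        # A bidirectional RNN concatenates forward and backward hidden states,
+        # doubling the feature size seen by the decoder.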
+ decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
+ self.decoder = nn.Sequential(
+ nn.Linear(in_features=decoder_size, out_features=num_classes),
+ nn.LogSoftmax(dim=2),
+ )
+
+ def _configure_sliding_window(self) -> nn.Sequential:
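+        """Returns an unfold + rearrange pipeline that cuts the image into
+        patches of size patch_size with the given stride, shaped as a
+        (B, T, 1, patch_h, patch_w) tensor.
+        """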
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+        """Converts the image into a sequence of patches, encodes them with a CNN, and predicts tokens with an RNN."""
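+        # Add leading dimensions until the input is a 4D (B, C, H, W) tensor.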
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
+ if self.sliding_window is not None:
+ # Create image patches with a sliding window kernel.
+ x = self.sliding_window(x)
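+            # x now has shape (B, T, 1, patch_h, patch_w).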
+
+            # Flatten the batch and time dimensions so every patch passes
+            # through the CNN independently.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+
+ x = self.backbone(x)
+
+            # Average pool the CNN feature maps over the spatial dimensions.
+ if self.avg_pool:
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ else:
+ x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ else:
+            # Encode the entire image with the CNN and use the width as the
+            # temporal dimension (time-first, to match the RNN input layout).
+            x = self.backbone(x)
+            x = rearrange(x, "b c h w -> w b c h")
+ if self.adaptive_pool is not None:
+ x = self.adaptive_pool(x)
+ x = x.squeeze(3)
+
+ # Sequence predictions.
+ x, _ = self.rnn(x)
+
+ # Sequence to classification layer.
+ x = self.decoder(x)
+ return x
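+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the module proper. The backbone name and
+    # arguments below are assumptions for illustration: any backbone registered
+    # with configure_backbone that maps (B * T, 1, 28, 28) patches to flat
+    # feature vectors of size input_size should work here.
+    import torch
+
+    network = ConvolutionalRecurrentNetwork(
+        backbone="lenet",  # Hypothetical backbone name.
+        backbone_args={},
+        input_size=128,
+        hidden_size=128,
+        num_classes=80,
+    )
+    dummy_batch = torch.randn(2, 1, 28, 952)  # (B, C, H, W) line images.
+    log_probs = network(dummy_batch)
+    # Expected shape: (T, B, num_classes) log-probabilities, e.g. for a CTC loss.
+    print(log_probs.shape)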