summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/crnn.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
commitdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/networks/crnn.py
parente181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/text_recognizer/networks/crnn.py')
-rw-r--r--src/text_recognizer/networks/crnn.py108
1 files changed, 108 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
new file mode 100644
index 0000000..9747429
--- /dev/null
+++ b/src/text_recognizer/networks/crnn.py
@@ -0,0 +1,108 @@
+"""LSTM with CTC for handwritten text recognition within a line."""
+from typing import Dict, Tuple
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import configure_backbone
+
+
+class ConvolutionalRecurrentNetwork(nn.Module):
+ """Network that takes a image of a text line and predicts tokens that are in the image."""
+
+ def __init__(
+ self,
+ backbone: str,
+ backbone_args: Dict = None,
+ input_size: int = 128,
+ hidden_size: int = 128,
+ bidirectional: bool = False,
+ num_layers: int = 1,
+ num_classes: int = 80,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ recurrent_cell: str = "lstm",
+ avg_pool: bool = False,
+ use_sliding_window: bool = True,
+ ) -> None:
+ super().__init__()
+ self.backbone_args = backbone_args or {}
+ self.patch_size = patch_size
+ self.stride = stride
+ self.sliding_window = (
+ self._configure_sliding_window() if use_sliding_window else None
+ )
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.bidirectional = bidirectional
+ self.avg_pool = avg_pool
+
+ if recurrent_cell.upper() in ["LSTM", "GRU"]:
+ recurrent_cell = getattr(nn, recurrent_cell)
+ else:
+ logger.warning(
+ f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
+ )
+ recurrent_cell = nn.LSTM
+
+ self.rnn = recurrent_cell(
+ input_size=self.input_size,
+ hidden_size=self.hidden_size,
+ bidirectional=bidirectional,
+ num_layers=num_layers,
+ )
+
+ decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
+ self.decoder = nn.Sequential(
+ nn.Linear(in_features=decoder_size, out_features=num_classes),
+ nn.LogSoftmax(dim=2),
+ )
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
+ if self.sliding_window is not None:
+ # Create image patches with a sliding window kernel.
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+
+ x = self.backbone(x)
+
+ # Avgerage pooling.
+ if self.avg_pool:
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ else:
+ x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ else:
+ # Encode the entire image with a CNN, and use the channels as temporal dimension.
+ b = x.shape[0]
+ x = self.backbone(x)
+ x = rearrange(x, "b c h w -> c b (h w)", b=b)
+
+ # Sequence predictions.
+ x, _ = self.rnn(x)
+
+ # Sequence to classifcation layer.
+ x = self.decoder(x)
+ return x