diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
commit | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch) | |
tree | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/networks/line_lstm_ctc.py | |
parent | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff) |
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/networks/line_lstm_ctc.py')
-rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 2e2c3a5..988b615 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -1,5 +1,81 @@ """LSTM with CTC for handwritten text recognition within a line.""" +import importlib +from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from einops import rearrange, reduce +from einops.layers.torch import Rearrange, Reduce import torch from torch import nn from torch import Tensor + + +class LineRecurrentNetwork(nn.Module): + """Network that takes a image of a text line and predicts tokens that are in the image.""" + + def __init__( + self, + encoder: str, + encoder_args: Dict = None, + flatten: bool = True, + input_size: int = 128, + hidden_size: int = 128, + num_layers: int = 1, + num_classes: int = 80, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + ) -> None: + super().__init__() + self.encoder_args = encoder_args or {} + self.patch_size = patch_size + self.stride = stride + self.sliding_window = self._configure_sliding_window() + self.input_size = input_size + self.hidden_size = hidden_size + self.encoder = self._configure_encoder(encoder) + self.flatten = flatten + self.rnn = nn.LSTM( + input_size=self.input_size, + hidden_size=self.hidden_size, + num_layers=num_layers, + ) + self.decoder = nn.Sequential( + nn.Linear(in_features=self.hidden_size, out_features=num_classes), + nn.LogSoftmax(dim=2), + ) + + def _configure_encoder(self, encoder: str) -> Type[nn.Module]: + network_module = importlib.import_module("text_recognizer.networks") + encoder_ = getattr(network_module, encoder) + return encoder_(**self.encoder_args) + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def forward(self, x: Tensor) -> Tensor: + """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + x = self.encoder(x) + + # Avgerage pooling. + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x + + # Sequence predictions. + x, _ = self.rnn(x) + + # Sequence to classifcation layer. + x = self.decoder(x) + return x |