From 53677be4ec14854ea4881b0d78730e0414c8dedd Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 9 Aug 2020 23:24:02 +0200 Subject: Working bash scripts etc. --- src/text_recognizer/networks/ctc.py | 10 +++++++++ src/text_recognizer/networks/lenet.py | 19 +++++----------- src/text_recognizer/networks/line_lstm_ctc.py | 4 ++++ src/text_recognizer/networks/misc.py | 28 ++++++++++++++++++++++++ src/text_recognizer/networks/mlp.py | 9 ++++++-- src/text_recognizer/networks/residual_network.py | 1 + 6 files changed, 55 insertions(+), 16 deletions(-) create mode 100644 src/text_recognizer/networks/ctc.py create mode 100644 src/text_recognizer/networks/line_lstm_ctc.py create mode 100644 src/text_recognizer/networks/misc.py create mode 100644 src/text_recognizer/networks/residual_network.py (limited to 'src/text_recognizer/networks') diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py new file mode 100644 index 0000000..00ad47e --- /dev/null +++ b/src/text_recognizer/networks/ctc.py @@ -0,0 +1,10 @@ +"""Decodes the CTC output.""" +# +# from typing import Tuple +# import torch +# +# +# def greedy_decoder( +# output, labels, label_length, blank_label, collapse_repeated=True +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# pass diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 2839a0c..cbc58fc 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -1,24 +1,16 @@ """Defines the LeNet network.""" from typing import Callable, Dict, Optional, Tuple +from einops.layers.torch import Rearrange import torch from torch import nn -class Flatten(nn.Module): - """Flattens a tensor.""" - - def forward(self, x: int) -> torch.Tensor: - """Flattens a tensor for input to a nn.Linear layer.""" - return torch.flatten(x, start_dim=1) - - class LeNet(nn.Module): """LeNet network.""" def __init__( self, - input_size: Tuple[int, ...] = (1, 28, 28), channels: Tuple[int, ...] = (1, 32, 64), kernel_sizes: Tuple[int, ...] = (3, 3, 2), hidden_size: Tuple[int, ...] = (9216, 128), @@ -30,7 +22,6 @@ class LeNet(nn.Module): """The LeNet network. Args: - input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28). channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. @@ -44,10 +35,9 @@ class LeNet(nn.Module): """ super().__init__() - self._input_size = input_size - if activation_fn is not None: - activation_fn = getattr(nn, activation_fn)(activation_fn_args) + activation_fn_args = activation_fn_args or {} + activation_fn = getattr(nn, activation_fn)(**activation_fn_args) else: activation_fn = nn.ReLU(inplace=True) @@ -66,7 +56,7 @@ class LeNet(nn.Module): activation_fn, nn.MaxPool2d(kernel_sizes[2]), nn.Dropout(p=dropout_rate), - Flatten(), + Rearrange("b c h w -> b (c h w)"), nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), activation_fn, nn.Dropout(p=dropout_rate), @@ -77,6 +67,7 @@ class LeNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward.""" + # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) return self.layers(x) diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py new file mode 100644 index 0000000..d704139 --- /dev/null +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -0,0 +1,4 @@ +"""LSTM with CTC for handwritten text recognition within a line.""" + +import torch +from torch import nn diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py new file mode 100644 index 0000000..9440f9d --- /dev/null +++ b/src/text_recognizer/networks/misc.py @@ -0,0 +1,28 @@ +"""Miscellaneous neural network functionality.""" +from typing import Tuple + +from einops import rearrange +import torch +from torch.nn import Unfold + + +def sliding_window( + images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] +) -> torch.Tensor: + """Creates patches of an image. + + Args: + images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). + patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. + stride (Tuple[int, int]): The stride of the sliding window. + + Returns: + torch.Tensor: A tensor with the shape (batch, patches, height, width). + + """ + unfold = Unfold(kernel_size=patch_size, stride=stride) + patches = unfold(images) + patches = rearrange( + patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1] + ) + return patches diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index d704d99..ac2c825 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -1,6 +1,7 @@ """Defines the MLP network.""" from typing import Callable, Dict, List, Optional, Union +from einops.layers.torch import Rearrange import torch from torch import nn @@ -34,7 +35,8 @@ class MLP(nn.Module): super().__init__() if activation_fn is not None: - activation_fn = getattr(nn, activation_fn)(activation_fn_args) + activation_fn_args = activation_fn_args or {} + activation_fn = getattr(nn, activation_fn)(**activation_fn_args) else: activation_fn = nn.ReLU(inplace=True) @@ -42,6 +44,7 @@ class MLP(nn.Module): hidden_size = [hidden_size] * num_layers self.layers = [ + Rearrange("b c h w -> b (c h w)"), nn.Linear(in_features=input_size, out_features=hidden_size[0]), activation_fn, ] @@ -63,7 +66,9 @@ class MLP(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward.""" - x = torch.flatten(x, start_dim=1) + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) return self.layers(x) @property diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py new file mode 100644 index 0000000..23394b0 --- /dev/null +++ b/src/text_recognizer/networks/residual_network.py @@ -0,0 +1 @@ +"""Residual CNN.""" -- cgit v1.2.3-70-g09d2