diff options
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 17 | ||||
-rw-r--r-- | src/text_recognizer/networks/ctc.py | 66 | ||||
-rw-r--r-- | src/text_recognizer/networks/lenet.py | 12 | ||||
-rw-r--r-- | src/text_recognizer/networks/line_lstm_ctc.py | 76 | ||||
-rw-r--r-- | src/text_recognizer/networks/misc.py | 5 | ||||
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 6 | ||||
-rw-r--r-- | src/text_recognizer/networks/residual_network.py | 35 | ||||
-rw-r--r-- | src/text_recognizer/networks/stn.py | 44 | ||||
-rw-r--r-- | src/text_recognizer/networks/wide_resnet.py | 214 |
9 files changed, 435 insertions, 40 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index a83ca35..d20c86a 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,6 +1,19 @@ """Network modules.""" +from .ctc import greedy_decoder from .lenet import LeNet +from .line_lstm_ctc import LineRecurrentNetwork +from .misc import sliding_window from .mlp import MLP -from .residual_network import ResidualNetwork +from .residual_network import ResidualNetwork, ResidualNetworkEncoder +from .wide_resnet import WideResidualNetwork -__all__ = ["MLP", "LeNet", "ResidualNetwork"] +__all__ = [ + "greedy_decoder", + "MLP", + "LeNet", + "LineRecurrentNetwork", + "ResidualNetwork", + "ResidualNetworkEncoder", + "sliding_window", + "WideResidualNetwork", +] diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index 00ad47e..fc0d21d 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -1,10 +1,58 @@ """Decodes the CTC output.""" -# -# from typing import Tuple -# import torch -# -# -# def greedy_decoder( -# output, labels, label_length, blank_label, collapse_repeated=True -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# pass +from typing import Callable, List, Optional, Tuple + +from einops import rearrange +import torch +from torch import Tensor + +from text_recognizer.datasets import EmnistMapper + + +def greedy_decoder( + predictions: Tensor, + targets: Optional[Tensor] = None, + target_lengths: Optional[Tensor] = None, + character_mapper: Optional[Callable] = None, + blank_label: int = 79, + collapse_repeated: bool = True, +) -> Tuple[List[str], List[str]]: + """Greedy CTC decoder. + + Args: + predictions (Tensor): Tenor of network predictions, shape [time, batch, classes]. + targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None. + target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None. + character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults + to None. + blank_label (int): The blank character to be ignored. Defaults to 79. + collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True. + + Returns: + Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets. + + """ + + if character_mapper is None: + character_mapper = EmnistMapper() + + predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") + decoded_predictions = [] + decoded_targets = [] + for i, prediction in enumerate(predictions): + decoded_prediction = [] + decoded_target = [] + if targets is not None and target_lengths is not None: + for target_index in targets[i][: target_lengths[i]]: + if target_index == blank_label: + continue + decoded_target.append(character_mapper(int(target_index))) + decoded_targets.append(decoded_target) + for j, index in enumerate(prediction): + if index != blank_label: + if collapse_repeated and j != 0 and index == prediction[j - 1]: + continue + decoded_prediction.append(index.item()) + decoded_predictions.append( + [character_mapper(int(pred_index)) for pred_index in decoded_prediction] + ) + return decoded_predictions, decoded_targets diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 91d3f2c..53c575e 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -1,4 +1,4 @@ -"""Defines the LeNet network.""" +"""Implementation of the LeNet network.""" from typing import Callable, Dict, Optional, Tuple from einops.layers.torch import Rearrange @@ -9,7 +9,7 @@ from text_recognizer.networks.misc import activation_function class LeNet(nn.Module): - """LeNet network.""" + """LeNet network for character prediction.""" def __init__( self, @@ -17,10 +17,10 @@ class LeNet(nn.Module): kernel_sizes: Tuple[int, ...] = (3, 3, 2), hidden_size: Tuple[int, ...] = (9216, 128), dropout_rate: float = 0.2, - output_size: int = 10, + num_classes: int = 10, activation_fn: Optional[str] = "relu", ) -> None: - """The LeNet network. + """Initialization of the LeNet network. Args: channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). @@ -28,7 +28,7 @@ class LeNet(nn.Module): hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. Defaults to (9216, 128). dropout_rate (float): The dropout rate. Defaults to 0.2. - output_size (int): Number of classes. Defaults to 10. + num_classes (int): Number of classes. Defaults to 10. activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. """ @@ -55,7 +55,7 @@ class LeNet(nn.Module): nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), activation_fn, nn.Dropout(p=dropout_rate), - nn.Linear(in_features=hidden_size[1], out_features=output_size), + nn.Linear(in_features=hidden_size[1], out_features=num_classes), ] self.layers = nn.Sequential(*self.layers) diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 2e2c3a5..988b615 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -1,5 +1,81 @@ """LSTM with CTC for handwritten text recognition within a line.""" +import importlib +from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from einops import rearrange, reduce +from einops.layers.torch import Rearrange, Reduce import torch from torch import nn from torch import Tensor + + +class LineRecurrentNetwork(nn.Module): + """Network that takes a image of a text line and predicts tokens that are in the image.""" + + def __init__( + self, + encoder: str, + encoder_args: Dict = None, + flatten: bool = True, + input_size: int = 128, + hidden_size: int = 128, + num_layers: int = 1, + num_classes: int = 80, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + ) -> None: + super().__init__() + self.encoder_args = encoder_args or {} + self.patch_size = patch_size + self.stride = stride + self.sliding_window = self._configure_sliding_window() + self.input_size = input_size + self.hidden_size = hidden_size + self.encoder = self._configure_encoder(encoder) + self.flatten = flatten + self.rnn = nn.LSTM( + input_size=self.input_size, + hidden_size=self.hidden_size, + num_layers=num_layers, + ) + self.decoder = nn.Sequential( + nn.Linear(in_features=self.hidden_size, out_features=num_classes), + nn.LogSoftmax(dim=2), + ) + + def _configure_encoder(self, encoder: str) -> Type[nn.Module]: + network_module = importlib.import_module("text_recognizer.networks") + encoder_ = getattr(network_module, encoder) + return encoder_(**self.encoder_args) + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def forward(self, x: Tensor) -> Tensor: + """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + x = self.encoder(x) + + # Avgerage pooling. + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x + + # Sequence predictions. + x, _ = self.rnn(x) + + # Sequence to classifcation layer. + x = self.decoder(x) + return x diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index 6f61b5d..cac9e78 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -22,9 +22,10 @@ def sliding_window( """ unfold = nn.Unfold(kernel_size=patch_size, stride=stride) # Preform the slidning window, unsqueeze as the channel dimesion is lost. - patches = unfold(images).unsqueeze(1) + c = images.shape[1] + patches = unfold(images) patches = rearrange( - patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1] + patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1] ) return patches diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index acebdaa..d66af28 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -14,7 +14,7 @@ class MLP(nn.Module): def __init__( self, input_size: int = 784, - output_size: int = 10, + num_classes: int = 10, hidden_size: Union[int, List] = 128, num_layers: int = 3, dropout_rate: float = 0.2, @@ -24,7 +24,7 @@ class MLP(nn.Module): Args: input_size (int): The input shape of the network. Defaults to 784. - output_size (int): Number of classes in the dataset. Defaults to 10. + num_classes (int): Number of classes in the dataset. Defaults to 10. hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. num_layers (int): The number of hidden layers. Defaults to 3. dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. @@ -55,7 +55,7 @@ class MLP(nn.Module): self.layers.append(nn.Dropout(p=dropout_rate)) self.layers.append( - nn.Linear(in_features=hidden_size[-1], out_features=output_size) + nn.Linear(in_features=hidden_size[-1], out_features=num_classes) ) self.layers = nn.Sequential(*self.layers) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 47e351a..1b5d6b3 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -8,6 +8,7 @@ from torch import nn from torch import Tensor from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.stn import SpatialTransformerNetwork class Conv2dAuto(nn.Conv2d): @@ -197,25 +198,28 @@ class ResidualLayer(nn.Module): return x -class Encoder(nn.Module): +class ResidualNetworkEncoder(nn.Module): """Encoder network.""" def __init__( self, in_channels: int = 1, - block_sizes: List[int] = (32, 64), - depths: List[int] = (2, 2), + block_sizes: Union[int, List[int]] = (32, 64), + depths: Union[int, List[int]] = (2, 2), activation: str = "relu", block: Type[nn.Module] = BasicBlock, + levels: int = 1, + stn: bool = False, *args, **kwargs ) -> None: super().__init__() - - self.block_sizes = block_sizes - self.depths = depths + self.stn = SpatialTransformerNetwork() if stn else None + self.block_sizes = ( + block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels + ) + self.depths = depths if isinstance(depths, list) else [depths] * levels self.activation = activation - self.gate = nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -227,7 +231,7 @@ class Encoder(nn.Module): ), nn.BatchNorm2d(self.block_sizes[0]), activation_function(self.activation), - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), ) self.blocks = self._configure_blocks(block) @@ -271,11 +275,13 @@ class Encoder(nn.Module): # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) + if self.stn is not None: + x = self.stn(x) x = self.gate(x) return self.blocks(x) -class Decoder(nn.Module): +class ResidualNetworkDecoder(nn.Module): """Classification head.""" def __init__(self, in_features: int, num_classes: int = 80) -> None: @@ -295,19 +301,12 @@ class ResidualNetwork(nn.Module): def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: super().__init__() - self.encoder = Encoder(in_channels, *args, **kwargs) - self.decoder = Decoder( + self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) + self.decoder = ResidualNetworkDecoder( in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, num_classes=num_classes, ) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - def forward(self, x: Tensor) -> Tensor: """Forward pass.""" x = self.encoder(x) diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py new file mode 100644 index 0000000..b031128 --- /dev/null +++ b/src/text_recognizer/networks/stn.py @@ -0,0 +1,44 @@ +"""Spatial Transformer Network.""" + +from einops.layers.torch import Rearrange +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class SpatialTransformerNetwork(nn.Module): + """A network with differentiable attention. + + Network that learns how to perform spatial transformations on the input image in order to enhance the + geometric invariance of the model. + + # TODO: add arguements to make it more general. + + """ + + def __init__(self) -> None: + super().__init__() + # Initialize the identity transformation and its weights and biases. + linear = nn.Linear(32, 3 * 2) + linear.weight.data.zero_() + linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) + + self.theta = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.ReLU(inplace=True), + Rearrange("b c h w -> b (c h w)", h=3, w=3), + nn.Linear(in_features=10 * 3 * 3, out_features=32), + nn.ReLU(inplace=True), + linear, + Rearrange("b (row col) -> b row col", row=2, col=3), + ) + + def forward(self, x: Tensor) -> Tensor: + """The spatial transformation.""" + grid = F.affine_grid(self.theta(x), x.shape) + return F.grid_sample(x, grid, align_corners=False) diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py new file mode 100644 index 0000000..d1c8f9a --- /dev/null +++ b/src/text_recognizer/networks/wide_resnet.py @@ -0,0 +1,214 @@ +"""Wide Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import numpy as np +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.misc import activation_function + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """Helper function for a 3x3 2d convolution.""" + return nn.Conv2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + + +def conv_init(module: Type[nn.Module]) -> None: + """Initializes the weights for convolution and batchnorms.""" + classname = module.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) + nn.init.constant(module.bias, 0) + elif classname.find("BatchNorm") != -1: + nn.init.constant(module.weight, 1) + nn.init.constant(module.bias, 0) + + +class WideBlock(nn.Module): + """Block used in WideResNet.""" + + def __init__( + self, + in_planes: int, + out_planes: int, + dropout_rate: float, + stride: int = 1, + activation: str = "relu", + ) -> None: + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.dropout_rate = dropout_rate + self.stride = stride + self.activation = activation_function(activation) + + # Build blocks. + self.blocks = nn.Sequential( + nn.BatchNorm2d(self.in_planes), + self.activation, + conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), + nn.Dropout(p=self.dropout_rate), + nn.BatchNorm2d(self.out_planes), + self.activation, + conv3x3( + in_planes=self.out_planes, + out_planes=self.out_planes, + stride=self.stride, + ), + ) + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_planes, + out_channels=self.out_planes, + kernel_size=1, + stride=self.stride, + bias=False, + ), + ) + if self._apply_shortcut + else None + ) + + @property + def _apply_shortcut(self) -> bool: + """If shortcut should be applied or not.""" + return self.stride != 1 or self.in_planes != self.out_planes + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self._apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + return x + + +class WideResidualNetwork(nn.Module): + """WideResNet for character predictions. + + Can be used for classification or encoding of images to a latent vector. + + """ + + def __init__( + self, + in_channels: int = 1, + in_planes: int = 16, + num_classes: int = 80, + depth: int = 16, + width_factor: int = 10, + dropout_rate: float = 0.0, + num_layers: int = 3, + block: Type[nn.Module] = WideBlock, + activation: str = "relu", + use_decoder: bool = True, + ) -> None: + """The initialization of the WideResNet. + + Args: + in_channels (int): Number of input channels. Defaults to 1. + in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. + num_classes (int): Number of classes. Defaults to 80. + depth (int): Set the number of blocks to use. Defaults to 16. + width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. + dropout_rate (float): The dropout rate. Defaults to 0.0. + num_layers (int): Number of layers of blocks. Defaults to 3. + block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. + activation (str): Name of the activation to use. Defaults to "relu". + use_decoder (bool): If True, the network output character predictions, if False, the network outputs a + latent vector. Defaults to True. + + Raises: + RuntimeError: If the depth is not of the size `6n+4`. + + """ + + super().__init__() + if (depth - 4) % 6 != 0: + raise RuntimeError("Wide-resnet depth should be 6n+4") + self.in_channels = in_channels + self.in_planes = in_planes + self.num_classes = num_classes + self.num_blocks = (depth - 4) // 6 + self.width_factor = width_factor + self.num_layers = num_layers + self.block = block + self.dropout_rate = dropout_rate + self.activation = activation_function(activation) + + self.num_stages = [self.in_planes] + [ + self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers) + ] + self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) + self.strides = [1] + [2] * (self.num_layers - 1) + + self.encoder = nn.Sequential( + conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), + *[ + self._configure_wide_layer( + in_planes=in_planes, + out_planes=out_planes, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip( + self.num_stages, self.strides + ) + ], + ) + + self.decoder = ( + nn.Sequential( + nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), + self.activation, + Reduce("b c h w -> b c", "mean"), + nn.Linear( + in_features=self.num_stages[-1][-1], out_features=self.num_classes + ), + ) + if use_decoder + else None + ) + + self.apply(conv_init) + + def _configure_wide_layer( + self, in_planes: int, out_planes: int, stride: int, activation: str + ) -> List: + strides = [stride] + [1] * (self.num_blocks - 1) + planes = [out_planes] * len(strides) + planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) + return nn.Sequential( + *[ + self.block( + in_planes=in_planes, + out_planes=out_planes, + dropout_rate=self.dropout_rate, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip(planes, strides) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass.""" + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.encoder(x) + if self.decoder is not None: + x = self.decoder(x) + return x |