summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/__init__.py17
-rw-r--r--src/text_recognizer/networks/ctc.py66
-rw-r--r--src/text_recognizer/networks/lenet.py12
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py76
-rw-r--r--src/text_recognizer/networks/misc.py5
-rw-r--r--src/text_recognizer/networks/mlp.py6
-rw-r--r--src/text_recognizer/networks/residual_network.py35
-rw-r--r--src/text_recognizer/networks/stn.py44
-rw-r--r--src/text_recognizer/networks/wide_resnet.py214
9 files changed, 435 insertions, 40 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index a83ca35..d20c86a 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,6 +1,19 @@
"""Network modules."""
+from .ctc import greedy_decoder
from .lenet import LeNet
+from .line_lstm_ctc import LineRecurrentNetwork
+from .misc import sliding_window
from .mlp import MLP
-from .residual_network import ResidualNetwork
+from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .wide_resnet import WideResidualNetwork
-__all__ = ["MLP", "LeNet", "ResidualNetwork"]
+__all__ = [
+ "greedy_decoder",
+ "MLP",
+ "LeNet",
+ "LineRecurrentNetwork",
+ "ResidualNetwork",
+ "ResidualNetworkEncoder",
+ "sliding_window",
+ "WideResidualNetwork",
+]
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 00ad47e..fc0d21d 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -1,10 +1,58 @@
"""Decodes the CTC output."""
-#
-# from typing import Tuple
-# import torch
-#
-#
-# def greedy_decoder(
-# output, labels, label_length, blank_label, collapse_repeated=True
-# ) -> Tuple[torch.Tensor, torch.Tensor]:
-# pass
+from typing import Callable, List, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import Tensor
+
+from text_recognizer.datasets import EmnistMapper
+
+
+def greedy_decoder(
+ predictions: Tensor,
+ targets: Optional[Tensor] = None,
+ target_lengths: Optional[Tensor] = None,
+ character_mapper: Optional[Callable] = None,
+ blank_label: int = 79,
+ collapse_repeated: bool = True,
+) -> Tuple[List[str], List[str]]:
+ """Greedy CTC decoder.
+
+ Args:
+ predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
+ targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
+ target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
+ character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
+ to None.
+ blank_label (int): The blank character to be ignored. Defaults to 79.
+ collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
+
+ Returns:
+ Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
+
+ """
+
+ if character_mapper is None:
+ character_mapper = EmnistMapper()
+
+ predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
+ decoded_predictions = []
+ decoded_targets = []
+ for i, prediction in enumerate(predictions):
+ decoded_prediction = []
+ decoded_target = []
+ if targets is not None and target_lengths is not None:
+ for target_index in targets[i][: target_lengths[i]]:
+ if target_index == blank_label:
+ continue
+ decoded_target.append(character_mapper(int(target_index)))
+ decoded_targets.append(decoded_target)
+ for j, index in enumerate(prediction):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == prediction[j - 1]:
+ continue
+ decoded_prediction.append(index.item())
+ decoded_predictions.append(
+ [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
+ )
+ return decoded_predictions, decoded_targets
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 91d3f2c..53c575e 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,4 +1,4 @@
-"""Defines the LeNet network."""
+"""Implementation of the LeNet network."""
from typing import Callable, Dict, Optional, Tuple
from einops.layers.torch import Rearrange
@@ -9,7 +9,7 @@ from text_recognizer.networks.misc import activation_function
class LeNet(nn.Module):
- """LeNet network."""
+ """LeNet network for character prediction."""
def __init__(
self,
@@ -17,10 +17,10 @@ class LeNet(nn.Module):
kernel_sizes: Tuple[int, ...] = (3, 3, 2),
hidden_size: Tuple[int, ...] = (9216, 128),
dropout_rate: float = 0.2,
- output_size: int = 10,
+ num_classes: int = 10,
activation_fn: Optional[str] = "relu",
) -> None:
- """The LeNet network.
+ """Initialization of the LeNet network.
Args:
channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
@@ -28,7 +28,7 @@ class LeNet(nn.Module):
hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
Defaults to (9216, 128).
dropout_rate (float): The dropout rate. Defaults to 0.2.
- output_size (int): Number of classes. Defaults to 10.
+ num_classes (int): Number of classes. Defaults to 10.
activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
"""
@@ -55,7 +55,7 @@ class LeNet(nn.Module):
nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
activation_fn,
nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=hidden_size[1], out_features=output_size),
+ nn.Linear(in_features=hidden_size[1], out_features=num_classes),
]
self.layers = nn.Sequential(*self.layers)
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index 2e2c3a5..988b615 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -1,5 +1,81 @@
"""LSTM with CTC for handwritten text recognition within a line."""
+import importlib
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange, Reduce
import torch
from torch import nn
from torch import Tensor
+
+
+class LineRecurrentNetwork(nn.Module):
+ """Network that takes a image of a text line and predicts tokens that are in the image."""
+
+ def __init__(
+ self,
+ encoder: str,
+ encoder_args: Dict = None,
+ flatten: bool = True,
+ input_size: int = 128,
+ hidden_size: int = 128,
+ num_layers: int = 1,
+ num_classes: int = 80,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ ) -> None:
+ super().__init__()
+ self.encoder_args = encoder_args or {}
+ self.patch_size = patch_size
+ self.stride = stride
+ self.sliding_window = self._configure_sliding_window()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.encoder = self._configure_encoder(encoder)
+ self.flatten = flatten
+ self.rnn = nn.LSTM(
+ input_size=self.input_size,
+ hidden_size=self.hidden_size,
+ num_layers=num_layers,
+ )
+ self.decoder = nn.Sequential(
+ nn.Linear(in_features=self.hidden_size, out_features=num_classes),
+ nn.LogSoftmax(dim=2),
+ )
+
+ def _configure_encoder(self, encoder: str) -> Type[nn.Module]:
+ network_module = importlib.import_module("text_recognizer.networks")
+ encoder_ = getattr(network_module, encoder)
+ return encoder_(**self.encoder_args)
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+ x = self.encoder(x)
+
+ # Avgerage pooling.
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x
+
+ # Sequence predictions.
+ x, _ = self.rnn(x)
+
+ # Sequence to classifcation layer.
+ x = self.decoder(x)
+ return x
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 6f61b5d..cac9e78 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -22,9 +22,10 @@ def sliding_window(
"""
unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
# Preform the slidning window, unsqueeze as the channel dimesion is lost.
- patches = unfold(images).unsqueeze(1)
+ c = images.shape[1]
+ patches = unfold(images)
patches = rearrange(
- patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
+ patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
)
return patches
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index acebdaa..d66af28 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -14,7 +14,7 @@ class MLP(nn.Module):
def __init__(
self,
input_size: int = 784,
- output_size: int = 10,
+ num_classes: int = 10,
hidden_size: Union[int, List] = 128,
num_layers: int = 3,
dropout_rate: float = 0.2,
@@ -24,7 +24,7 @@ class MLP(nn.Module):
Args:
input_size (int): The input shape of the network. Defaults to 784.
- output_size (int): Number of classes in the dataset. Defaults to 10.
+ num_classes (int): Number of classes in the dataset. Defaults to 10.
hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
num_layers (int): The number of hidden layers. Defaults to 3.
dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
@@ -55,7 +55,7 @@ class MLP(nn.Module):
self.layers.append(nn.Dropout(p=dropout_rate))
self.layers.append(
- nn.Linear(in_features=hidden_size[-1], out_features=output_size)
+ nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
)
self.layers = nn.Sequential(*self.layers)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 47e351a..1b5d6b3 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -8,6 +8,7 @@ from torch import nn
from torch import Tensor
from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.stn import SpatialTransformerNetwork
class Conv2dAuto(nn.Conv2d):
@@ -197,25 +198,28 @@ class ResidualLayer(nn.Module):
return x
-class Encoder(nn.Module):
+class ResidualNetworkEncoder(nn.Module):
"""Encoder network."""
def __init__(
self,
in_channels: int = 1,
- block_sizes: List[int] = (32, 64),
- depths: List[int] = (2, 2),
+ block_sizes: Union[int, List[int]] = (32, 64),
+ depths: Union[int, List[int]] = (2, 2),
activation: str = "relu",
block: Type[nn.Module] = BasicBlock,
+ levels: int = 1,
+ stn: bool = False,
*args,
**kwargs
) -> None:
super().__init__()
-
- self.block_sizes = block_sizes
- self.depths = depths
+ self.stn = SpatialTransformerNetwork() if stn else None
+ self.block_sizes = (
+ block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
+ )
+ self.depths = depths if isinstance(depths, list) else [depths] * levels
self.activation = activation
-
self.gate = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -227,7 +231,7 @@ class Encoder(nn.Module):
),
nn.BatchNorm2d(self.block_sizes[0]),
activation_function(self.activation),
- nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
)
self.blocks = self._configure_blocks(block)
@@ -271,11 +275,13 @@ class Encoder(nn.Module):
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
+ if self.stn is not None:
+ x = self.stn(x)
x = self.gate(x)
return self.blocks(x)
-class Decoder(nn.Module):
+class ResidualNetworkDecoder(nn.Module):
"""Classification head."""
def __init__(self, in_features: int, num_classes: int = 80) -> None:
@@ -295,19 +301,12 @@ class ResidualNetwork(nn.Module):
def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
super().__init__()
- self.encoder = Encoder(in_channels, *args, **kwargs)
- self.decoder = Decoder(
+ self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
+ self.decoder = ResidualNetworkDecoder(
in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
num_classes=num_classes,
)
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
def forward(self, x: Tensor) -> Tensor:
"""Forward pass."""
x = self.encoder(x)
diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py
new file mode 100644
index 0000000..b031128
--- /dev/null
+++ b/src/text_recognizer/networks/stn.py
@@ -0,0 +1,44 @@
+"""Spatial Transformer Network."""
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+class SpatialTransformerNetwork(nn.Module):
+ """A network with differentiable attention.
+
+ Network that learns how to perform spatial transformations on the input image in order to enhance the
+ geometric invariance of the model.
+
+ # TODO: add arguements to make it more general.
+
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ # Initialize the identity transformation and its weights and biases.
+ linear = nn.Linear(32, 3 * 2)
+ linear.weight.data.zero_()
+ linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+ self.theta = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.ReLU(inplace=True),
+ Rearrange("b c h w -> b (c h w)", h=3, w=3),
+ nn.Linear(in_features=10 * 3 * 3, out_features=32),
+ nn.ReLU(inplace=True),
+ linear,
+ Rearrange("b (row col) -> b row col", row=2, col=3),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """The spatial transformation."""
+ grid = F.affine_grid(self.theta(x), x.shape)
+ return F.grid_sample(x, grid, align_corners=False)
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
new file mode 100644
index 0000000..d1c8f9a
--- /dev/null
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -0,0 +1,214 @@
+"""Wide Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.misc import activation_function
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """Helper function for a 3x3 2d convolution."""
+ return nn.Conv2d(
+ in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ )
+
+
+def conv_init(module: Type[nn.Module]) -> None:
+ """Initializes the weights for convolution and batchnorms."""
+ classname = module.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
+ nn.init.constant(module.bias, 0)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.constant(module.weight, 1)
+ nn.init.constant(module.bias, 0)
+
+
+class WideBlock(nn.Module):
+ """Block used in WideResNet."""
+
+ def __init__(
+ self,
+ in_planes: int,
+ out_planes: int,
+ dropout_rate: float,
+ stride: int = 1,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.dropout_rate = dropout_rate
+ self.stride = stride
+ self.activation = activation_function(activation)
+
+ # Build blocks.
+ self.blocks = nn.Sequential(
+ nn.BatchNorm2d(self.in_planes),
+ self.activation,
+ conv3x3(in_planes=self.in_planes, out_planes=self.out_planes),
+ nn.Dropout(p=self.dropout_rate),
+ nn.BatchNorm2d(self.out_planes),
+ self.activation,
+ conv3x3(
+ in_planes=self.out_planes,
+ out_planes=self.out_planes,
+ stride=self.stride,
+ ),
+ )
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_planes,
+ out_channels=self.out_planes,
+ kernel_size=1,
+ stride=self.stride,
+ bias=False,
+ ),
+ )
+ if self._apply_shortcut
+ else None
+ )
+
+ @property
+ def _apply_shortcut(self) -> bool:
+ """If shortcut should be applied or not."""
+ return self.stride != 1 or self.in_planes != self.out_planes
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self._apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ return x
+
+
+class WideResidualNetwork(nn.Module):
+ """WideResNet for character predictions.
+
+ Can be used for classification or encoding of images to a latent vector.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ in_planes: int = 16,
+ num_classes: int = 80,
+ depth: int = 16,
+ width_factor: int = 10,
+ dropout_rate: float = 0.0,
+ num_layers: int = 3,
+ block: Type[nn.Module] = WideBlock,
+ activation: str = "relu",
+ use_decoder: bool = True,
+ ) -> None:
+ """The initialization of the WideResNet.
+
+ Args:
+ in_channels (int): Number of input channels. Defaults to 1.
+ in_planes (int): Number of channels to use in the first output kernel. Defaults to 16.
+ num_classes (int): Number of classes. Defaults to 80.
+ depth (int): Set the number of blocks to use. Defaults to 16.
+ width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10.
+ dropout_rate (float): The dropout rate. Defaults to 0.0.
+ num_layers (int): Number of layers of blocks. Defaults to 3.
+ block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ activation (str): Name of the activation to use. Defaults to "relu".
+ use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
+ latent vector. Defaults to True.
+
+ Raises:
+ RuntimeError: If the depth is not of the size `6n+4`.
+
+ """
+
+ super().__init__()
+ if (depth - 4) % 6 != 0:
+ raise RuntimeError("Wide-resnet depth should be 6n+4")
+ self.in_channels = in_channels
+ self.in_planes = in_planes
+ self.num_classes = num_classes
+ self.num_blocks = (depth - 4) // 6
+ self.width_factor = width_factor
+ self.num_layers = num_layers
+ self.block = block
+ self.dropout_rate = dropout_rate
+ self.activation = activation_function(activation)
+
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
+ ]
+ self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
+ self.strides = [1] + [2] * (self.num_layers - 1)
+
+ self.encoder = nn.Sequential(
+ conv3x3(in_planes=self.in_channels, out_planes=self.in_planes),
+ *[
+ self._configure_wide_layer(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(
+ self.num_stages, self.strides
+ )
+ ],
+ )
+
+ self.decoder = (
+ nn.Sequential(
+ nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8),
+ self.activation,
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(
+ in_features=self.num_stages[-1][-1], out_features=self.num_classes
+ ),
+ )
+ if use_decoder
+ else None
+ )
+
+ self.apply(conv_init)
+
+ def _configure_wide_layer(
+ self, in_planes: int, out_planes: int, stride: int, activation: str
+ ) -> List:
+ strides = [stride] + [1] * (self.num_blocks - 1)
+ planes = [out_planes] * len(strides)
+ planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:]))
+ return nn.Sequential(
+ *[
+ self.block(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ dropout_rate=self.dropout_rate,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(planes, strides)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass."""
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.encoder(x)
+ if self.decoder is not None:
+ x = self.decoder(x)
+ return x