summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/__init__.py3
-rw-r--r--src/text_recognizer/networks/lenet.py17
-rw-r--r--src/text_recognizer/networks/misc.py20
-rw-r--r--src/text_recognizer/networks/mlp.py18
-rw-r--r--src/text_recognizer/networks/residual_network.py314
5 files changed, 346 insertions, 26 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index e6b6946..a83ca35 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,5 +1,6 @@
"""Network modules."""
from .lenet import LeNet
from .mlp import MLP
+from .residual_network import ResidualNetwork
-__all__ = ["MLP", "LeNet"]
+__all__ = ["MLP", "LeNet", "ResidualNetwork"]
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index cbc58fc..91d3f2c 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
+from text_recognizer.networks.misc import activation_function
+
class LeNet(nn.Module):
"""LeNet network."""
@@ -16,8 +18,7 @@ class LeNet(nn.Module):
hidden_size: Tuple[int, ...] = (9216, 128),
dropout_rate: float = 0.2,
output_size: int = 10,
- activation_fn: Optional[Callable] = None,
- activation_fn_args: Optional[Dict] = None,
+ activation_fn: Optional[str] = "relu",
) -> None:
"""The LeNet network.
@@ -28,18 +29,12 @@ class LeNet(nn.Module):
Defaults to (9216, 128).
dropout_rate (float): The dropout rate. Defaults to 0.2.
output_size (int): Number of classes. Defaults to 10.
- activation_fn (Optional[Callable]): The non-linear activation function. Defaults to
- nn.ReLU(inplace).
- activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
+ activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
"""
super().__init__()
- if activation_fn is not None:
- activation_fn_args = activation_fn_args or {}
- activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
- else:
- activation_fn = nn.ReLU(inplace=True)
+ activation_fn = activation_function(activation_fn)
self.layers = [
nn.Conv2d(
@@ -66,7 +61,7 @@ class LeNet(nn.Module):
self.layers = nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward."""
+ """The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 2fbab8f..6f61b5d 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -1,9 +1,9 @@
"""Miscellaneous neural network functionality."""
-from typing import Tuple
+from typing import Tuple, Type
from einops import rearrange
import torch
-from torch.nn import Unfold
+from torch import nn
def sliding_window(
@@ -20,10 +20,24 @@ def sliding_window(
torch.Tensor: A tensor with the shape (batch, patches, height, width).
"""
- unfold = Unfold(kernel_size=patch_size, stride=stride)
+ unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
# Preform the slidning window, unsqueeze as the channel dimesion is lost.
patches = unfold(images).unsqueeze(1)
patches = rearrange(
patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
)
return patches
+
+
+def activation_function(activation: str) -> Type[nn.Module]:
+ """Returns the callable activation function."""
+ activation_fns = nn.ModuleDict(
+ [
+ ["gelu", nn.GELU()],
+ ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+ ["none", nn.Identity()],
+ ["relu", nn.ReLU(inplace=True)],
+ ["selu", nn.SELU(inplace=True)],
+ ]
+ )
+ return activation_fns[activation.lower()]
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index ac2c825..acebdaa 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
+from text_recognizer.networks.misc import activation_function
+
class MLP(nn.Module):
"""Multi layered perceptron network."""
@@ -16,8 +18,7 @@ class MLP(nn.Module):
hidden_size: Union[int, List] = 128,
num_layers: int = 3,
dropout_rate: float = 0.2,
- activation_fn: Optional[Callable] = None,
- activation_fn_args: Optional[Dict] = None,
+ activation_fn: str = "relu",
) -> None:
"""Initialization of the MLP network.
@@ -27,18 +28,13 @@ class MLP(nn.Module):
hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
num_layers (int): The number of hidden layers. Defaults to 3.
dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
- activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to
- None.
- activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
"""
super().__init__()
- if activation_fn is not None:
- activation_fn_args = activation_fn_args or {}
- activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
- else:
- activation_fn = nn.ReLU(inplace=True)
+ activation_fn = activation_function(activation_fn)
if isinstance(hidden_size, int):
hidden_size = [hidden_size] * num_layers
@@ -65,7 +61,7 @@ class MLP(nn.Module):
self.layers = nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward."""
+ """The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 23394b0..47e351a 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -1 +1,315 @@
"""Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.misc import activation_function
+
+
+class Conv2dAuto(nn.Conv2d):
+ """Convolution with auto padding based on kernel size."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
+
+
+def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
+ """3x3 convolution with batch norm."""
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ return nn.Sequential(
+ conv3x3(in_channels, out_channels, *args, **kwargs),
+ nn.BatchNorm2d(out_channels),
+ )
+
+
+class IdentityBlock(nn.Module):
+ """Residual with identity block."""
+
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu"
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.blocks = nn.Identity()
+ self.activation_fn = activation_function(activation)
+ self.shortcut = nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self.apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ x = self.activation_fn(x)
+ return x
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.out_channels
+
+
+class ResidualBlock(IdentityBlock):
+ """Residual with nonlinear shortcut."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion: int = 1,
+ downsampling: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ """Short summary.
+
+ Args:
+ in_channels (int): Number of in channels.
+ out_channels (int): umber of out channels.
+ expansion (int): Expansion factor of the out channels. Defaults to 1.
+ downsampling (int): Downsampling factor used in stride. Defaults to 1.
+ *args (type): Extra arguments.
+ **kwargs (type): Extra key value arguments.
+
+ """
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.expansion = expansion
+ self.downsampling = downsampling
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ stride=self.downsampling,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.expanded_channels),
+ )
+ if self.apply_shortcut
+ else None
+ )
+
+ @property
+ def expanded_channels(self) -> int:
+ """Computes the expanded output channels."""
+ return self.out_channels * self.expansion
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.expanded_channels
+
+
+class BasicBlock(ResidualBlock):
+ """Basic ResNet block."""
+
+ expansion = 1
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ bias=False,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ bias=False,
+ ),
+ )
+
+
+class BottleNeckBlock(ResidualBlock):
+ """Bottleneck block to increase depth while minimizing parameter size."""
+
+ expansion = 4
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ ),
+ )
+
+
+class ResidualLayer(nn.Module):
+ """ResNet layer."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ block: BasicBlock = BasicBlock,
+ num_blocks: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ downsampling = 2 if in_channels != out_channels else 1
+ self.blocks = nn.Sequential(
+ block(
+ in_channels, out_channels, *args, **kwargs, downsampling=downsampling
+ ),
+ *[
+ block(
+ out_channels * block.expansion,
+ out_channels,
+ downsampling=1,
+ *args,
+ **kwargs
+ )
+ for _ in range(num_blocks - 1)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.blocks(x)
+ return x
+
+
+class Encoder(nn.Module):
+ """Encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ block_sizes: List[int] = (32, 64),
+ depths: List[int] = (2, 2),
+ activation: str = "relu",
+ block: Type[nn.Module] = BasicBlock,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+
+ self.block_sizes = block_sizes
+ self.depths = depths
+ self.activation = activation
+
+ self.gate = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.block_sizes[0],
+ kernel_size=3,
+ stride=2,
+ padding=3,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.block_sizes[0]),
+ activation_function(self.activation),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+
+ self.blocks = self._configure_blocks(block)
+
+ def _configure_blocks(
+ self, block: Type[nn.Module], *args, **kwargs
+ ) -> nn.Sequential:
+ channels = [self.block_sizes[0]] + list(
+ zip(self.block_sizes, self.block_sizes[1:])
+ )
+ blocks = [
+ ResidualLayer(
+ in_channels=channels[0],
+ out_channels=channels[0],
+ num_blocks=self.depths[0],
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ ]
+ blocks += [
+ ResidualLayer(
+ in_channels=in_channels * block.expansion,
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ for (in_channels, out_channels), num_blocks in zip(
+ channels[1:], self.depths[1:]
+ )
+ ]
+
+ return nn.Sequential(*blocks)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.gate(x)
+ return self.blocks(x)
+
+
+class Decoder(nn.Module):
+ """Classification head."""
+
+ def __init__(self, in_features: int, num_classes: int = 80) -> None:
+ super().__init__()
+ self.decoder = nn.Sequential(
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(in_features=in_features, out_features=num_classes),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ return self.decoder(x)
+
+
+class ResidualNetwork(nn.Module):
+ """Full residual network."""
+
+ def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
+ super().__init__()
+ self.encoder = Encoder(in_channels, *args, **kwargs)
+ self.decoder = Decoder(
+ in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
+ num_classes=num_classes,
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.encoder(x)
+ x = self.decoder(x)
+ return x