diff options
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 3 | ||||
-rw-r--r-- | src/text_recognizer/networks/lenet.py | 17 | ||||
-rw-r--r-- | src/text_recognizer/networks/misc.py | 20 | ||||
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 18 | ||||
-rw-r--r-- | src/text_recognizer/networks/residual_network.py | 314 |
5 files changed, 346 insertions, 26 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index e6b6946..a83ca35 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,5 +1,6 @@ """Network modules.""" from .lenet import LeNet from .mlp import MLP +from .residual_network import ResidualNetwork -__all__ = ["MLP", "LeNet"] +__all__ = ["MLP", "LeNet", "ResidualNetwork"] diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index cbc58fc..91d3f2c 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class LeNet(nn.Module): """LeNet network.""" @@ -16,8 +18,7 @@ class LeNet(nn.Module): hidden_size: Tuple[int, ...] = (9216, 128), dropout_rate: float = 0.2, output_size: int = 10, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: Optional[str] = "relu", ) -> None: """The LeNet network. @@ -28,18 +29,12 @@ class LeNet(nn.Module): Defaults to (9216, 128). dropout_rate (float): The dropout rate. Defaults to 0.2. output_size (int): Number of classes. Defaults to 10. - activation_fn (Optional[Callable]): The non-linear activation function. Defaults to - nn.ReLU(inplace). - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) self.layers = [ nn.Conv2d( @@ -66,7 +61,7 @@ class LeNet(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index 2fbab8f..6f61b5d 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -1,9 +1,9 @@ """Miscellaneous neural network functionality.""" -from typing import Tuple +from typing import Tuple, Type from einops import rearrange import torch -from torch.nn import Unfold +from torch import nn def sliding_window( @@ -20,10 +20,24 @@ def sliding_window( torch.Tensor: A tensor with the shape (batch, patches, height, width). """ - unfold = Unfold(kernel_size=patch_size, stride=stride) + unfold = nn.Unfold(kernel_size=patch_size, stride=stride) # Preform the slidning window, unsqueeze as the channel dimesion is lost. patches = unfold(images).unsqueeze(1) patches = rearrange( patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1] ) return patches + + +def activation_function(activation: str) -> Type[nn.Module]: + """Returns the callable activation function.""" + activation_fns = nn.ModuleDict( + [ + ["gelu", nn.GELU()], + ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], + ["none", nn.Identity()], + ["relu", nn.ReLU(inplace=True)], + ["selu", nn.SELU(inplace=True)], + ] + ) + return activation_fns[activation.lower()] diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index ac2c825..acebdaa 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class MLP(nn.Module): """Multi layered perceptron network.""" @@ -16,8 +18,7 @@ class MLP(nn.Module): hidden_size: Union[int, List] = 128, num_layers: int = 3, dropout_rate: float = 0.2, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: str = "relu", ) -> None: """Initialization of the MLP network. @@ -27,18 +28,13 @@ class MLP(nn.Module): hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. num_layers (int): The number of hidden layers. Defaults to 3. dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. - activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to - None. - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) if isinstance(hidden_size, int): hidden_size = [hidden_size] * num_layers @@ -65,7 +61,7 @@ class MLP(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 23394b0..47e351a 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -1 +1,315 @@ """Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.misc import activation_function + + +class Conv2dAuto(nn.Conv2d): + """Convolution with auto padding based on kernel size.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) + + +def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: + """3x3 convolution with batch norm.""" + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + return nn.Sequential( + conv3x3(in_channels, out_channels, *args, **kwargs), + nn.BatchNorm2d(out_channels), + ) + + +class IdentityBlock(nn.Module): + """Residual with identity block.""" + + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu" + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.blocks = nn.Identity() + self.activation_fn = activation_function(activation) + self.shortcut = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self.apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + x = self.activation_fn(x) + return x + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.out_channels + + +class ResidualBlock(IdentityBlock): + """Residual with nonlinear shortcut.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: int = 1, + downsampling: int = 1, + *args, + **kwargs + ) -> None: + """Short summary. + + Args: + in_channels (int): Number of in channels. + out_channels (int): umber of out channels. + expansion (int): Expansion factor of the out channels. Defaults to 1. + downsampling (int): Downsampling factor used in stride. Defaults to 1. + *args (type): Extra arguments. + **kwargs (type): Extra key value arguments. + + """ + super().__init__(in_channels, out_channels, *args, **kwargs) + self.expansion = expansion + self.downsampling = downsampling + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.expanded_channels, + kernel_size=1, + stride=self.downsampling, + bias=False, + ), + nn.BatchNorm2d(self.expanded_channels), + ) + if self.apply_shortcut + else None + ) + + @property + def expanded_channels(self) -> int: + """Computes the expanded output channels.""" + return self.out_channels * self.expansion + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.expanded_channels + + +class BasicBlock(ResidualBlock): + """Basic ResNet block.""" + + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + bias=False, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + bias=False, + ), + ) + + +class BottleNeckBlock(ResidualBlock): + """Bottleneck block to increase depth while minimizing parameter size.""" + + expansion = 4 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + kernel_size=1, + ), + ) + + +class ResidualLayer(nn.Module): + """ResNet layer.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + block: BasicBlock = BasicBlock, + num_blocks: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + downsampling = 2 if in_channels != out_channels else 1 + self.blocks = nn.Sequential( + block( + in_channels, out_channels, *args, **kwargs, downsampling=downsampling + ), + *[ + block( + out_channels * block.expansion, + out_channels, + downsampling=1, + *args, + **kwargs + ) + for _ in range(num_blocks - 1) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.blocks(x) + return x + + +class Encoder(nn.Module): + """Encoder network.""" + + def __init__( + self, + in_channels: int = 1, + block_sizes: List[int] = (32, 64), + depths: List[int] = (2, 2), + activation: str = "relu", + block: Type[nn.Module] = BasicBlock, + *args, + **kwargs + ) -> None: + super().__init__() + + self.block_sizes = block_sizes + self.depths = depths + self.activation = activation + + self.gate = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=self.block_sizes[0], + kernel_size=3, + stride=2, + padding=3, + bias=False, + ), + nn.BatchNorm2d(self.block_sizes[0]), + activation_function(self.activation), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + self.blocks = self._configure_blocks(block) + + def _configure_blocks( + self, block: Type[nn.Module], *args, **kwargs + ) -> nn.Sequential: + channels = [self.block_sizes[0]] + list( + zip(self.block_sizes, self.block_sizes[1:]) + ) + blocks = [ + ResidualLayer( + in_channels=channels[0], + out_channels=channels[0], + num_blocks=self.depths[0], + block=block, + activation=self.activation, + *args, + **kwargs + ) + ] + blocks += [ + ResidualLayer( + in_channels=in_channels * block.expansion, + out_channels=out_channels, + num_blocks=num_blocks, + block=block, + activation=self.activation, + *args, + **kwargs + ) + for (in_channels, out_channels), num_blocks in zip( + channels[1:], self.depths[1:] + ) + ] + + return nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.gate(x) + return self.blocks(x) + + +class Decoder(nn.Module): + """Classification head.""" + + def __init__(self, in_features: int, num_classes: int = 80) -> None: + super().__init__() + self.decoder = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(in_features=in_features, out_features=num_classes), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.decoder(x) + + +class ResidualNetwork(nn.Module): + """Full residual network.""" + + def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: + super().__init__() + self.encoder = Encoder(in_channels, *args, **kwargs) + self.decoder = Decoder( + in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, + num_classes=num_classes, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.encoder(x) + x = self.decoder(x) + return x |