diff options
Diffstat (limited to 'src/text_recognizer/networks/cnn.py')
-rw-r--r-- | src/text_recognizer/networks/cnn.py | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py new file mode 100644 index 0000000..1807bb9 --- /dev/null +++ b/src/text_recognizer/networks/cnn.py @@ -0,0 +1,101 @@ +"""Implementation of a simple backbone cnn network.""" +from typing import Callable, Dict, Optional, Tuple + +from einops.layers.torch import Rearrange +import torch +from torch import nn + +from text_recognizer.networks.util import activation_function + + +class CNN(nn.Module): + """LeNet network for character prediction.""" + + def __init__( + self, + channels: Tuple[int, ...] = (1, 32, 64, 128), + kernel_sizes: Tuple[int, ...] = (4, 4, 4), + strides: Tuple[int, ...] = (2, 2, 2), + max_pool_kernel: int = 2, + dropout_rate: float = 0.2, + activation: Optional[str] = "relu", + ) -> None: + """Initialization of the LeNet network. + + Args: + channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). + kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). + strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2). + max_pool_kernel (int): 2D max pooling kernel. Defaults to 2. + dropout_rate (float): The dropout rate. Defaults to 0.2. + activation (Optional[str]): The name of non-linear activation function. Defaults to relu. + + Raises: + RuntimeError: if the number of hyperparameters does not match in length. + + """ + super().__init__() + + if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides): + raise RuntimeError("The number of the hyperparameters does not match.") + + self.cnn = self._build_network( + channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation, + ) + + def _build_network( + self, + channels: Tuple[int, ...], + kernel_sizes: Tuple[int, ...], + strides: Tuple[int, ...], + max_pool_kernel: int, + dropout_rate: float, + activation: str, + ) -> nn.Sequential: + # Load activation function. + activation_fn = activation_function(activation) + + channels = list(channels) + in_channels = channels.pop(0) + configuration = zip(channels, kernel_sizes, strides) + + modules = nn.ModuleList([]) + + for i, (out_channels, kernel_size, stride) in enumerate(configuration): + # Add max pool to reduce output size. + if i == len(channels) // 2: + modules.append(nn.MaxPool2d(max_pool_kernel)) + if i == 0: + modules.append( + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=1 + ) + ) + else: + modules.append( + nn.Sequential( + activation_fn, + nn.BatchNorm2d(in_channels), + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=1, + ), + ) + ) + + if dropout_rate: + modules.append(nn.Dropout2d(p=dropout_rate)) + + in_channels = out_channels + + return nn.Sequential(*modules) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.cnn(x) |