summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/cnn.py
blob: 1807bb9ecb5602d017d8c50a0815300df4335fef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Implementation of a simple backbone cnn network."""
from typing import Callable, Dict, Optional, Tuple

from einops.layers.torch import Rearrange
import torch
from torch import nn

from text_recognizer.networks.util import activation_function


class CNN(nn.Module):
    """LeNet-style convolutional backbone for character prediction.

    The network is a stack of ``Conv2d`` layers (``padding=1``, kernel size and stride taken
    from the constructor). Every convolution after the first is preceded by the activation
    function and a ``BatchNorm2d``; a ``MaxPool2d`` is inserted before the convolution at index
    ``num_layers // 2``, and a ``Dropout2d`` follows every convolution when ``dropout_rate`` is
    non-zero.
    """

    def __init__(
        self,
        channels: Tuple[int, ...] = (1, 32, 64, 128),
        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
        strides: Tuple[int, ...] = (2, 2, 2),
        max_pool_kernel: int = 2,
        dropout_rate: float = 0.2,
        activation: Optional[str] = "relu",
    ) -> None:
        """Initialization of the LeNet network.

        Args:
            channels (Tuple[int, ...]): Channels in the convolutional layers. The first entry is
                the number of input channels; each following entry is the output channel count
                of one convolution. Defaults to (1, 32, 64, 128).
            kernel_sizes (Tuple[int, ...]): Kernel size of each convolutional layer; must have
                ``len(channels) - 1`` entries. Defaults to (4, 4, 4).
            strides (Tuple[int, ...]): Stride of each convolutional layer; must have as many
                entries as ``kernel_sizes``. Defaults to (2, 2, 2).
            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
            dropout_rate (float): The dropout rate; 0 disables dropout. Defaults to 0.2.
            activation (Optional[str]): The name of the non-linear activation function, resolved
                through ``activation_function``. Defaults to "relu".

        Raises:
            RuntimeError: if the number of hyperparameters does not match in length.

        """
        super().__init__()

        # Each convolution needs one output channel count, one kernel size and one stride.
        # A mismatch in *either* pair must be rejected; otherwise ``zip`` in
        # ``_build_network`` would silently drop the surplus layers.
        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
            raise RuntimeError("The number of the hyperparameters does not match.")

        self.cnn = self._build_network(
            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
        )

    def _build_network(
        self,
        channels: Tuple[int, ...],
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        max_pool_kernel: int,
        dropout_rate: float,
        activation: Optional[str],
    ) -> nn.Sequential:
        """Assemble the convolutional stack.

        Args:
            channels (Tuple[int, ...]): Input channel count followed by per-layer output counts.
            kernel_sizes (Tuple[int, ...]): Per-layer kernel sizes.
            strides (Tuple[int, ...]): Per-layer strides.
            max_pool_kernel (int): Kernel of the single max pooling layer.
            dropout_rate (float): Dropout probability; falsy disables dropout.
            activation (Optional[str]): Name passed to ``activation_function``.

        Returns:
            nn.Sequential: The layers in forward order.

        """
        # Load activation function.
        # NOTE(review): this single module instance is reused by every layer. That is harmless
        # for stateless activations such as ReLU, but would share weights across layers if
        # ``activation_function`` returns a parameterised module (e.g. PReLU) — confirm.
        activation_fn = activation_function(activation)

        in_channels, *layer_out_channels = channels
        num_layers = len(layer_out_channels)

        modules = []
        for i, (out_channels, kernel_size, stride) in enumerate(
            zip(layer_out_channels, kernel_sizes, strides)
        ):
            # Add max pool halfway through the stack to reduce the output size.
            if i == num_layers // 2:
                modules.append(nn.MaxPool2d(max_pool_kernel))

            if i == 0:
                # The first convolution operates directly on the input.
                modules.append(
                    nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=1)
                )
            else:
                # Later convolutions are preceded by activation -> batch norm.
                modules.append(
                    nn.Sequential(
                        activation_fn,
                        nn.BatchNorm2d(in_channels),
                        nn.Conv2d(
                            in_channels, out_channels, kernel_size, stride=stride, padding=1
                        ),
                    )
                )

            if dropout_rate:
                modules.append(nn.Dropout2d(p=dropout_rate))

            in_channels = out_channels

        return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The feedforward pass.

        Args:
            x (torch.Tensor): Input tensor. If it has fewer than 4 dimensions, singleton
                dimensions are prepended until it is 4-D, e.g. (H, W) -> (1, 1, H, W) and
                (C, H, W) -> (1, C, H, W).

        Returns:
            torch.Tensor: Output of the convolutional stack.

        """
        # If the leading (batch / channel) dimensions are missing, they need to be added.
        if x.dim() < 4:
            x = x[(None,) * (4 - x.dim())]
        return self.cnn(x)