"""Implementation of a simple backbone cnn network."""
from typing import Callable, Dict, Optional, Tuple
from einops.layers.torch import Rearrange
import torch
from torch import nn
from text_recognizer.networks.util import activation_function
class CNN(nn.Module):
"""LeNet network for character prediction."""
def __init__(
self,
channels: Tuple[int, ...] = (1, 32, 64, 128),
kernel_sizes: Tuple[int, ...] = (4, 4, 4),
strides: Tuple[int, ...] = (2, 2, 2),
max_pool_kernel: int = 2,
dropout_rate: float = 0.2,
activation: Optional[str] = "relu",
) -> None:
"""Initialization of the LeNet network.
Args:
channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
dropout_rate (float): The dropout rate. Defaults to 0.2.
activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
Raises:
RuntimeError: if the number of hyperparameters does not match in length.
"""
super().__init__()
        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
            raise RuntimeError("The lengths of the hyperparameters do not match.")
self.cnn = self._build_network(
channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
)

    def _build_network(
self,
channels: Tuple[int, ...],
kernel_sizes: Tuple[int, ...],
strides: Tuple[int, ...],
max_pool_kernel: int,
dropout_rate: float,
activation: str,
) -> nn.Sequential:
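        """Builds the CNN as a sequential stack of convolutional blocks."""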
# Load activation function.
activation_fn = activation_function(activation)
channels = list(channels)
in_channels = channels.pop(0)
configuration = zip(channels, kernel_sizes, strides)
modules = nn.ModuleList([])
for i, (out_channels, kernel_size, stride) in enumerate(configuration):
            # Add a max pool halfway through the stack to reduce the spatial size.
            if i == len(channels) // 2:
modules.append(nn.MaxPool2d(max_pool_kernel))
if i == 0:
modules.append(
nn.Conv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=1
)
)
else:
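                # Subsequent blocks apply the activation and batch norm before
                # the convolution (a pre-activation ordering).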
modules.append(
nn.Sequential(
activation_fn,
nn.BatchNorm2d(in_channels),
nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=1,
),
)
)
if dropout_rate:
modules.append(nn.Dropout2d(p=dropout_rate))
in_channels = out_channels
return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The feedforward pass."""
        # If the batch (and channel) dimensions are missing, prepend singleton
        # dimensions until the tensor is 4D, i.e. (N, C, H, W).
        if len(x.shape) < 4:
            x = x[(None,) * (4 - len(x.shape))]
return self.cnn(x)
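

if __name__ == "__main__":
    # Minimal usage sketch: build the backbone with its default hyperparameters
    # and run a forward pass. The 28 x 28 grayscale input is only an example
    # size; the missing batch dimension is added automatically in `forward`.
    cnn = CNN()
    image = torch.randn(1, 28, 28)  # (C, H, W)
    features = cnn(image)
    print(features.shape)  # torch.Size([1, 128, 1, 1]) for a 28 x 28 input.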