"""Fully Convolutional Network (FCN) with dilated kernels for global context."""
from typing import List, Tuple, Type
import torch
from torch import nn
from torch import Tensor


from text_recognizer.networks.util import activation_function


class _DilatedBlock(nn.Module):
    """Residual block built from dilated convolutions.

    Two stacked dilated convolutions run alongside a 1x1 residual
    projection; the two branches are concatenated along the channel axis.
    """

    def __init__(
        self,
        channels: List[int],
        kernel_sizes: List[int],
        dilations: List[int],
        paddings: List[int],
        activation_fn: nn.Module,
    ) -> None:
        super().__init__()
        # Two stacked dilated convolutions; the second halves the channel
        # count so that concatenation with the residual branch restores the
        # full block width. Each padding must satisfy p = d * (k - 1) / 2
        # (e.g. kernel 7, dilation 3, padding 9) for the spatial size to be
        # preserved, otherwise the concatenation in forward() fails.
        self.dilation_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=channels[0],
                out_channels=channels[1],
                kernel_size=kernel_sizes[0],
                stride=1,
                dilation=dilations[0],
                padding=paddings[0],
            ),
            nn.Conv2d(
                in_channels=channels[1],
                out_channels=channels[1] // 2,
                kernel_size=kernel_sizes[1],
                stride=1,
                dilation=dilations[1],
                padding=paddings[1],
            ),
        )
        self.activation_fn = activation_fn

        # 1x1 convolution projecting the input to half the block width for
        # the residual branch.
        self.conv = nn.Conv2d(
            in_channels=channels[0],
            out_channels=channels[1] // 2,
            kernel_size=1,
            dilation=1,
            stride=1,
        )

    def forward(self, x: Tensor) -> Tensor:
        residual = self.conv(x)
        x = self.dilation_conv(x)
        # Concatenate dilated and residual branches along the channel axis.
        x = torch.cat((x, residual), dim=1)
        return self.activation_fn(x)


class FCN(nn.Module):
    """Fully convolutional network built from stacked dilated blocks.

    The first half of the blocks uses the smaller dilation/padding pair and
    the second half the larger one, so ``num_blocks`` must be even;
    ``base_channels`` must also be even for the channel split in each block
    to add back up to the full block width.
    """

    def __init__(
        self,
        in_channels: int,
        base_channels: int,
        out_channels: int,
        kernel_size: int,
        dilations: Tuple[int, int] = (3, 7),
        paddings: Tuple[int, int] = (9, 21),
        num_blocks: int = 14,
        activation: str = "elu",
    ) -> None:
        super().__init__()
        self.kernel_sizes = [kernel_size] * num_blocks
        # One (in, out) channel pair per block; every block after the first
        # operates at base_channels width.
        self.channels = [in_channels] + [base_channels] * (num_blocks - 1)
        self.out_channels = out_channels
        self.dilations = [dilations[0]] * (num_blocks // 2) + [dilations[1]] * (
            num_blocks // 2
        )
        self.paddings = [paddings[0]] * (num_blocks // 2) + [paddings[1]] * (
            num_blocks // 2
        )
        self.activation_fn = activation_function(activation)
        self.fcn = self._configure_fcn()

    def _configure_fcn(self) -> nn.Sequential:
        layers = []
        # Consume the channel/kernel/dilation/padding lists pairwise, one
        # pair per dilated block.
        for i in range(0, len(self.channels), 2):
            layers.append(
                _DilatedBlock(
                    self.channels[i : i + 2],
                    self.kernel_sizes[i : i + 2],
                    self.dilations[i : i + 2],
                    self.paddings[i : i + 2],
                    self.activation_fn,
                )
            )
        # Final 1x1 convolution maps the last block's output to the desired
        # number of output channels.
        layers.append(
            nn.Conv2d(self.channels[-1], self.out_channels, kernel_size=1, stride=1)
        )
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.fcn(x)
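

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original
    # module. The channel counts and input shape below are arbitrary; it
    # assumes kernel_size=7, for which the default dilation/padding pairs
    # satisfy p = d * (k - 1) / 2 so spatial dimensions are preserved end
    # to end.
    net = FCN(in_channels=1, base_channels=64, out_channels=80, kernel_size=7)
    dummy = torch.randn(1, 1, 28, 952)  # (batch, channels, height, width)
    logits = net(dummy)
    print(logits.shape)  # expected: torch.Size([1, 80, 28, 952])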