Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r-- | src/text_recognizer/networks/fcn.py | 99 |
1 file changed, 99 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/fcn.py b/src/text_recognizer/networks/fcn.py
new file mode 100644
index 0000000..f9c4fd4
--- /dev/null
+++ b/src/text_recognizer/networks/fcn.py
@@ -0,0 +1,99 @@
+"""Fully Convolutional Network (FCN) with dilated kernels for global context."""
+from typing import List, Tuple
+import torch
+from torch import nn
+from torch import Tensor
+
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DilatedBlock(nn.Module):
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        dilations: List[int],
+        paddings: List[int],
+        activation_fn: nn.Module,
+    ) -> None:
+        super().__init__()
+        self.dilation_conv = nn.Sequential(
+            nn.Conv2d(
+                in_channels=channels[0],
+                out_channels=channels[1],
+                kernel_size=kernel_sizes[0],
+                stride=1,
+                dilation=dilations[0],
+                padding=paddings[0],
+            ),
+            nn.Conv2d(
+                in_channels=channels[1],
+                out_channels=channels[1] // 2,
+                kernel_size=kernel_sizes[1],
+                stride=1,
+                dilation=dilations[1],
+                padding=paddings[1],
+            ),
+        )
+        self.activation_fn = activation_fn
+
+        self.conv = nn.Conv2d(
+            in_channels=channels[0],
+            out_channels=channels[1] // 2,
+            kernel_size=1,
+            dilation=1,
+            stride=1,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        residual = self.conv(x)
+        x = self.dilation_conv(x)
+        x = torch.cat((x, residual), dim=1)
+        return self.activation_fn(x)
+
+
+class FCN(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        base_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilations: Tuple[int, int] = (3, 7),
+        paddings: Tuple[int, int] = (9, 21),
+        num_blocks: int = 14,
+        activation: str = "elu",
+    ) -> None:
+        super().__init__()
+        self.kernel_sizes = [kernel_size] * num_blocks
+        self.channels = [in_channels] + [base_channels] * (num_blocks - 1)
+        self.out_channels = out_channels
+        self.dilations = [dilations[0]] * (num_blocks // 2) + [dilations[1]] * (
+            num_blocks // 2
+        )
+        self.paddings = [paddings[0]] * (num_blocks // 2) + [paddings[1]] * (
+            num_blocks // 2
+        )
+        self.activation_fn = activation_function(activation)
+        self.fcn = self._configure_fcn()
+
+    def _configure_fcn(self) -> nn.Sequential:
+        layers = []
+        for i in range(0, len(self.channels), 2):
+            layers.append(
+                _DilatedBlock(
+                    self.channels[i : i + 2],
+                    self.kernel_sizes[i : i + 2],
+                    self.dilations[i : i + 2],
+                    self.paddings[i : i + 2],
+                    self.activation_fn,
+                )
+            )
+        layers.append(
+            nn.Conv2d(self.channels[-1], self.out_channels, kernel_size=1, stride=1)
+        )
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.fcn(x)
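For context, below is a minimal usage sketch of the new module; it is not part of the commit. It assumes the repository's src/ directory is on PYTHONPATH so that text_recognizer.networks.fcn is importable, and that activation_function("elu") returns an instantiated module, as the forward pass requires. The input shape used here is an arbitrary example. Note that the default dilations (3, 7) and paddings (9, 21) preserve spatial size only for kernel_size=7: a conv with kernel k and dilation d needs padding d * (k - 1) / 2 to be size-preserving, i.e. 3d when k = 7.

    # Usage sketch (hypothetical; not part of the commit).
    import torch

    from text_recognizer.networks.fcn import FCN

    # With kernel_size=7, padding = 3 * dilation keeps height and width
    # unchanged through every dilated conv, so the 1x1 residual branch and
    # the dilated branch remain concatenable along the channel dimension.
    net = FCN(in_channels=1, base_channels=64, out_channels=80, kernel_size=7)

    x = torch.randn(1, 1, 28, 952)  # (batch, channels, height, width)
    logits = net(x)
    print(logits.shape)  # torch.Size([1, 80, 28, 952]) -- spatial dims kept

With the defaults, the 14 per-layer configs are consumed two at a time, yielding seven _DilatedBlock modules followed by a final 1x1 conv to out_channels. Inside each block, the 1x1 residual conv outputs channels[1] // 2 channels to match the dilated branch, so the concatenation restores channels[1], and the activation is applied once per block, after the concatenation.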