From e3b039c9adb4bce42ede4cb682a3ae71e797539a Mon Sep 17 00:00:00 2001
From: aktersnurra <grydholm@kth.se>
Date: Wed, 2 Dec 2020 23:48:10 +0100
Subject: added segmentation network.

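This adds an FCN that stacks pairs of dilated convolutions with a
1x1 residual projection; the two paths are concatenated to grow the
receptive field without downsampling, and a final 1x1 convolution
maps to the output channels.

Minimal usage sketch (not part of the patch; parameter values are
illustrative assumptions, and the defaults for dilations/paddings
preserve spatial size when kernel_size is 7):

    import torch
    from text_recognizer.networks.fcn import FCN

    # Illustrative settings; num_blocks defaults to 14 (assumed even).
    net = FCN(in_channels=1, base_channels=32, out_channels=2, kernel_size=7)
    logits = net(torch.randn(1, 1, 64, 64))  # -> shape (1, 2, 64, 64)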
---
 src/text_recognizer/networks/fcn.py | 115 ++++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 src/text_recognizer/networks/fcn.py

diff --git a/src/text_recognizer/networks/fcn.py b/src/text_recognizer/networks/fcn.py
new file mode 100644
index 0000000..f9c4fd4
--- /dev/null
+++ b/src/text_recognizer/networks/fcn.py
@@ -0,0 +1,115 @@
+"""Fully Convolutional Network (FCN) with dilated kernels for global context."""
+from typing import List, Tuple, Type
+import torch
+from torch import nn
+from torch import Tensor
+
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DilatedBlock(nn.Module):
+    """Two stacked dilated convolutions with a 1x1 residual projection."""
+
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        dilations: List[int],
+        paddings: List[int],
+        activation_fn: nn.Module,
+    ) -> None:
+        super().__init__()
+        self.dilation_conv = nn.Sequential(
+            nn.Conv2d(
+                in_channels=channels[0],
+                out_channels=channels[1],
+                kernel_size=kernel_sizes[0],
+                stride=1,
+                dilation=dilations[0],
+                padding=paddings[0],
+            ),
+            nn.Conv2d(
+                in_channels=channels[1],
+                out_channels=channels[1] // 2,
+                kernel_size=kernel_sizes[1],
+                stride=1,
+                dilation=dilations[1],
+                padding=paddings[1],
+            ),
+        )
+        self.activation_fn = activation_fn
+
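+        # 1x1 convolution projecting the input for the residual path.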
+        self.conv = nn.Conv2d(
+            in_channels=channels[0],
+            out_channels=channels[1] // 2,
+            kernel_size=1,
+            dilation=1,
+            stride=1,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        residual = self.conv(x)
+        x = self.dilation_conv(x)
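+        # Concatenate the dilated path and the 1x1 residual projection
+        # along the channel dimension; each contributes channels[1] // 2.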
+        x = torch.cat((x, residual), dim=1)
+        return self.activation_fn(x)
+
+
+class FCN(nn.Module):
+    """Fully Convolutional Network built from stacked dilated blocks.
+
+    Spatial resolution is preserved throughout: with the default
+    dilations (3, 7) and paddings (9, 21), a kernel_size of 7 keeps
+    the output the same size as the input.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        base_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilations: Tuple[int, int] = (3, 7),
+        paddings: Tuple[int, int] = (9, 21),
+        num_blocks: int = 14,
+        activation: str = "elu",
+    ) -> None:
+        super().__init__()
+        # Each _DilatedBlock consumes two consecutive entries from these
+        # lists, so num_blocks is assumed to be even.
+        self.kernel_sizes = [kernel_size] * num_blocks
+        self.channels = [in_channels] + [base_channels] * (num_blocks - 1)
+        self.out_channels = out_channels
+        self.dilations = [dilations[0]] * (num_blocks // 2) + [dilations[1]] * (
+            num_blocks // 2
+        )
+        self.paddings = [paddings[0]] * (num_blocks // 2) + [paddings[1]] * (
+            num_blocks // 2
+        )
+        self.activation_fn = activation_function(activation)
+        self.fcn = self._configure_fcn()
+
+    def _configure_fcn(self) -> nn.Sequential:
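+        """Builds the stack of dilated blocks followed by a 1x1 output conv."""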
+        layers = []
+        for i in range(0, len(self.channels), 2):
+            layers.append(
+                _DilatedBlock(
+                    self.channels[i : i + 2],
+                    self.kernel_sizes[i : i + 2],
+                    self.dilations[i : i + 2],
+                    self.paddings[i : i + 2],
+                    self.activation_fn,
+                )
+            )
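+        # Final 1x1 convolution mapping to the requested output channels.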
+        layers.append(
+            nn.Conv2d(self.channels[-1], self.out_channels, kernel_size=1, stride=1)
+        )
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.fcn(x)
-- 
cgit v1.2.3-70-g09d2