From 73ae250d7993fa48eccff4042ecd6bf768650bf3 Mon Sep 17 00:00:00 2001
From: aktersnurra <grydholm@kth.se>
Date: Wed, 18 Nov 2020 23:35:35 +0100
Subject: Implement UNet forward pass and remove SparseMLP

---
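Notes: the UNet in unet.py gains a complete encoder/decoder forward pass. The
deepest DownSamplingBlock skips pooling and serves as the bottleneck, decode()
consumes the skip connections in reverse order, and a 1x1 convolution head maps
base_channels to num_classes. SparseMLP is dropped from the package. Below is a
minimal shape-check sketch (patch note only, not part of the commit), assuming
ConvBlock preserves spatial resolution (padded 3x3 convolutions) and that the
package is importable as shown:

    import torch
    from text_recognizer.networks.unet import UNet

    # depth=4 gives the encoder channel progression [1, 64, 128, 256, 512];
    # the last encoder block has no pooling and acts as the bottleneck.
    net = UNet(in_channels=1, base_channels=64, num_classes=3, depth=4)
    x = torch.randn(2, 1, 256, 256)  # (batch, channels, height, width)
    logits = net(x)                  # 1x1 head: base_channels -> num_classes
    assert logits.shape == (2, 3, 256, 256)
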
 src/text_recognizer/networks/__init__.py   |  2 -
 src/text_recognizer/networks/sparse_mlp.py | 78 ------------------------------
 src/text_recognizer/networks/unet.py       | 64 ++++++++++++++++--------
 3 files changed, 44 insertions(+), 100 deletions(-)
 delete mode 100644 src/text_recognizer/networks/sparse_mlp.py

(limited to 'src/text_recognizer/networks')

diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 2cc1137..67e245c 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -6,7 +6,6 @@ from .densenet import DenseNet
 from .lenet import LeNet
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
-from .sparse_mlp import SparseMLP
 from .transformer import Transformer
 from .util import sliding_window
 from .wide_resnet import WideResidualNetwork
@@ -22,6 +21,5 @@ __all__ = [
     "ResidualNetworkEncoder",
     "sliding_window",
     "Transformer",
-    "SparseMLP",
     "WideResidualNetwork",
 ]
diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py
deleted file mode 100644
index 53cf166..0000000
--- a/src/text_recognizer/networks/sparse_mlp.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""Defines the Sparse MLP network."""
-from typing import Callable, Dict, List, Optional, Union
-import warnings
-
-from einops.layers.torch import Rearrange
-from pytorch_block_sparse import BlockSparseLinear
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-warnings.filterwarnings("ignore", category=DeprecationWarning)
-
-
-class SparseMLP(nn.Module):
-    """Sparse multi layered perceptron network."""
-
-    def __init__(
-        self,
-        input_size: int = 784,
-        num_classes: int = 10,
-        hidden_size: Union[int, List] = 128,
-        num_layers: int = 3,
-        density: float = 0.1,
-        activation_fn: str = "relu",
-    ) -> None:
-        """Initialization of the MLP network.
-
-        Args:
-            input_size (int): The input shape of the network. Defaults to 784.
-            num_classes (int): Number of classes in the dataset. Defaults to 10.
-            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
-            num_layers (int): The number of hidden layers. Defaults to 3.
-            density (float): The density of activation at each layer. Default to 0.1.
-            activation_fn (str): Name of the activation function in the hidden layers. Defaults to
-                relu.
-
-        """
-        super().__init__()
-
-        activation_fn = activation_function(activation_fn)
-
-        if isinstance(hidden_size, int):
-            hidden_size = [hidden_size] * num_layers
-
-        self.layers = [
-            Rearrange("b c h w -> b (c h w)"),
-            nn.Linear(in_features=input_size, out_features=hidden_size[0]),
-            activation_fn,
-        ]
-
-        for i in range(num_layers - 1):
-            self.layers += [
-                BlockSparseLinear(
-                    in_features=hidden_size[i],
-                    out_features=hidden_size[i + 1],
-                    density=density,
-                ),
-                activation_fn,
-            ]
-
-        self.layers.append(
-            nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
-        )
-
-        self.layers = nn.Sequential(*self.layers)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """The feedforward pass."""
-        # If batch dimension is missing, it needs to be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.layers(x)
-
-    @property
-    def __name__(self) -> str:
-        """Returns the name of the network."""
-        return "mlp"
diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py
index eb4188b..51f242a 100644
--- a/src/text_recognizer/networks/unet.py
+++ b/src/text_recognizer/networks/unet.py
@@ -1,5 +1,5 @@
 """UNet for segmentation."""
-from typing import List, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -39,16 +39,23 @@ class DownSamplingBlock(nn.Module):
     """Basic down sampling block."""
 
     def __init__(
-        self, channels: List[int], activation: str, pooling_kernel: int = 2
+        self,
+        channels: List[int],
+        activation: str,
+        pooling_kernel: Union[int, bool] = 2,
     ) -> None:
         super().__init__()
         self.conv_block = ConvBlock(channels, activation)
-        self.down_sampling = nn.MaxPool2d(pooling_kernel)
+        self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
 
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         """Return the convolutional block output and a down sampled tensor."""
         x = self.conv_block(x)
-        return self.down_sampling(x), x
+        if self.down_sampling is not None:
+            x_down = self.down_sampling(x)
+        else:
+            x_down = None
+        return x_down, x
 
 
 class UpSamplingBlock(nn.Module):
@@ -63,10 +70,11 @@ class UpSamplingBlock(nn.Module):
             scale_factor=scale_factor, mode="bilinear", align_corners=True
         )
 
-    def forward(self, x: Tensor, x_skip: Tensor) -> Tensor:
+    def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
         """Apply the up sampling and convolutional block."""
         x = self.up_sampling(x)
-        x = torch.cat((x, x_skip), dim=1)
+        if x_skip is not None:
+            x = torch.cat((x, x_skip), dim=1)
         return self.conv_block(x)
 
 
@@ -77,6 +85,7 @@ class UNet(nn.Module):
         self,
         in_channels: int = 1,
         base_channels: int = 64,
+        num_classes: int = 3,
         depth: int = 4,
         out_channels: int = 3,
         activation: str = "relu",
@@ -84,27 +93,32 @@ class UNet(nn.Module):
         scale_factor: int = 2,
     ) -> None:
         super().__init__()
-        channels = [base_channels * 2 ** i for i in range(depth)]
-        self.down_sampling_blocks = self._configure_down_sampling_blocks(
+        self.depth = depth
+        channels = [in_channels] + [base_channels * 2 ** i for i in range(depth)]
+        self.encoder_blocks = self._configure_down_sampling_blocks(
             channels, activation, pooling_kernel
         )
-        self.up_sampling_blocks = self._configure_up_sampling_blocks(
+        self.decoder_blocks = self._configure_up_sampling_blocks(
             channels, activation, scale_factor
         )
 
+        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
+
     def _configure_down_sampling_blocks(
         self, channels: List[int], activation: str, pooling_kernel: int
     ) -> nn.ModuleList:
-        return nn.ModuleList(
-            [
+        blocks = nn.ModuleList([])
+        for i in range(len(channels) - 1):
+            pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+            blocks += [
                 DownSamplingBlock(
                     [channels[i], channels[i + 1], channels[i + 1]],
                     activation,
                     pooling_kernel,
                 )
-                for i in range(len(channels))
             ]
-        )
+
+        return blocks
 
     def _configure_up_sampling_blocks(
         self,
@@ -112,23 +126,33 @@ class UNet(nn.Module):
         activation: str,
         scale_factor: int,
     ) -> nn.ModuleList:
+        channels.reverse()
         return nn.ModuleList(
             [
                 UpSamplingBlock(
-                    [channels[i], channels[i + 1], channels[i + 1]],
+                    [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
                     activation,
                     scale_factor,
                 )
+                for i in range(len(channels) - 2)
             ]
-            for i in range(len(channels))
         )
 
-    def down_sampling(self, x: Tensor) -> List[Tensor]:
+    def encode(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
         x_skips = []
-        for block in self.down_sampling_blocks:
+        for block in self.encoder_blocks:
             x, x_skip = block(x)
-            x_skips.append(x_skip)
+            if x_skip is not None:
+                x_skips.append(x_skip)
         return x, x_skips
 
-    def up_sampling(self, x: Tensor, x_skips: List[Tensor]) -> Tensor:
-        pass
+    def decode(self, x: Tensor, x_skips: List[Tensor]) -> Tensor:
+        x = x_skips[-1]
+        for i, block in enumerate(self.decoder_blocks):
+            x = block(x, x_skips[-(i + 2)])
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        x, x_skips = self.encode(x)
+        x = self.decode(x, x_skips)
+        return self.head(x)
-- 
cgit v1.2.3-70-g09d2