summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/densenet.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/densenet.py')
-rw-r--r--src/text_recognizer/networks/densenet.py225
1 files changed, 225 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
new file mode 100644
index 0000000..7dc58d9
--- /dev/null
+++ b/src/text_recognizer/networks/densenet.py
@@ -0,0 +1,225 @@
+"""Defines a Densely Connected Convolutional Networks in PyTorch.
+
+Sources:
+https://arxiv.org/abs/1608.06993
+https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
+
+"""
+from typing import List, Optional, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DenseLayer(nn.Module):
+ """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ growth_rate: int,
+ bn_size: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.dense_layer = [
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=bn_size * growth_rate,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(bn_size * growth_rate),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=bn_size * growth_rate,
+ out_channels=growth_rate,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ ]
+ if dropout_rate:
+ self.dense_layer.append(nn.Dropout(p=dropout_rate))
+
+ self.dense_layer = nn.Sequential(*self.dense_layer)
+
+ def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
+ if isinstance(x, list):
+ x = torch.cat(x, 1)
+ return self.dense_layer(x)
+
+
+class _DenseBlock(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.dense_block = self._build_dense_blocks(
+ num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
+ )
+
+ def _build_dense_blocks(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> nn.ModuleList:
+ dense_block = []
+ for i in range(num_layers):
+ dense_block.append(
+ _DenseLayer(
+ in_channels=in_channels + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ return nn.ModuleList(dense_block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ feature_maps = [x]
+ for layer in self.dense_block:
+ x = layer(feature_maps)
+ feature_maps.append(x)
+ return torch.cat(feature_maps, 1)
+
+
+class _Transition(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.transition = nn.Sequential(
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.AvgPool2d(kernel_size=2, stride=2),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.transition(x)
+
+
+class DenseNet(nn.Module):
+ """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
+
+ def __init__(
+ self,
+ growth_rate: int = 32,
+ block_config: List[int] = (6, 12, 24, 16),
+ in_channels: int = 1,
+ base_channels: int = 64,
+ num_classes: int = 80,
+ bn_size: int = 4,
+ dropout_rate: float = 0,
+ classifier: bool = True,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.densenet = self._configure_densenet(
+ in_channels,
+ base_channels,
+ num_classes,
+ growth_rate,
+ block_config,
+ bn_size,
+ dropout_rate,
+ classifier,
+ activation,
+ )
+
+ def _configure_densenet(
+ self,
+ in_channels: int,
+ base_channels: int,
+ num_classes: int,
+ growth_rate: int,
+ block_config: List[int],
+ bn_size: int,
+ dropout_rate: float,
+ classifier: bool,
+ activation: str,
+ ) -> nn.Sequential:
+ activation_fn = activation_function(activation)
+ densenet = [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=base_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(base_channels),
+ activation_fn,
+ ]
+
+ num_features = base_channels
+
+ for i, num_layers in enumerate(block_config):
+ densenet.append(
+ _DenseBlock(
+ num_layers=num_layers,
+ in_channels=num_features,
+ bn_size=bn_size,
+ growth_rate=growth_rate,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ num_features = num_features + num_layers * growth_rate
+ if i != len(block_config) - 1:
+ densenet.append(
+ _Transition(
+ in_channels=num_features,
+ out_channels=num_features // 2,
+ activation=activation,
+ )
+ )
+ num_features = num_features // 2
+
+ densenet.append(activation_fn)
+
+ if classifier:
+ densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
+ densenet.append(Rearrange("b c h w -> b (c h w)"))
+ densenet.append(
+ nn.Linear(in_features=num_features, out_features=num_classes)
+ )
+
+ return nn.Sequential(*densenet)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass of Densenet."""
+ # If batch dimenstion is missing, it will be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.densenet(x)