Diffstat (limited to 'src/text_recognizer/networks/densenet.py')
-rw-r--r--  src/text_recognizer/networks/densenet.py | 225
1 file changed, 0 insertions(+), 225 deletions(-)
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
deleted file mode 100644
index 7dc58d9..0000000
--- a/src/text_recognizer/networks/densenet.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""Defines a Densely Connected Convolutional Networks in PyTorch.
-
-Sources:
-https://arxiv.org/abs/1608.06993
-https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
-
-"""
-from typing import List, Tuple, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DenseLayer(nn.Module):
- """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
-
- def __init__(
- self,
- in_channels: int,
- growth_rate: int,
- bn_size: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- activation_fn = activation_function(activation)
- self.dense_layer = [
- nn.BatchNorm2d(in_channels),
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=bn_size * growth_rate,
- kernel_size=1,
- stride=1,
- bias=False,
- ),
- nn.BatchNorm2d(bn_size * growth_rate),
- activation_fn,
- nn.Conv2d(
- in_channels=bn_size * growth_rate,
- out_channels=growth_rate,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- ),
- ]
- if dropout_rate:
- self.dense_layer.append(nn.Dropout(p=dropout_rate))
-
- self.dense_layer = nn.Sequential(*self.dense_layer)
-
- def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
- if isinstance(x, list):
- x = torch.cat(x, 1)
- return self.dense_layer(x)
-
-
-class _DenseBlock(nn.Module):
-    """A dense block: a stack of dense layers, each fed the concatenation of all preceding feature maps."""
-
-    def __init__(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.dense_block = self._build_dense_blocks(
-            num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
-        )
-
-    def _build_dense_blocks(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> nn.ModuleList:
-        dense_block = []
-        for i in range(num_layers):
-            dense_block.append(
-                _DenseLayer(
-                    # Each layer sees the block input plus all feature maps
-                    # produced by the layers before it.
-                    in_channels=in_channels + i * growth_rate,
-                    growth_rate=growth_rate,
-                    bn_size=bn_size,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-        return nn.ModuleList(dense_block)
-
-    def forward(self, x: Tensor) -> Tensor:
-        feature_maps = [x]
-        for layer in self.dense_block:
-            x = layer(feature_maps)
-            feature_maps.append(x)
-        # Concatenate the input and all layer outputs along the channel dimension.
-        return torch.cat(feature_maps, 1)
-
-
-class _Transition(nn.Module):
-    """A transition layer: a 1x1 convolution to reduce the channels, then 2x2 average pooling to halve the spatial resolution."""
-
-    def __init__(
-        self, in_channels: int, out_channels: int, activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        activation_fn = activation_function(activation)
-        self.transition = nn.Sequential(
-            nn.BatchNorm2d(in_channels),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-            nn.AvgPool2d(kernel_size=2, stride=2),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.transition(x)
-
-
-class DenseNet(nn.Module):
- """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
-
-    def __init__(
-        self,
-        growth_rate: int = 32,
-        block_config: Tuple[int, ...] = (6, 12, 24, 16),
-        in_channels: int = 1,
-        base_channels: int = 64,
-        num_classes: int = 80,
-        bn_size: int = 4,
-        dropout_rate: float = 0,
-        classifier: bool = True,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.densenet = self._configure_densenet(
-            in_channels,
-            base_channels,
-            num_classes,
-            growth_rate,
-            block_config,
-            bn_size,
-            dropout_rate,
-            classifier,
-            activation,
-        )
-
-    def _configure_densenet(
-        self,
-        in_channels: int,
-        base_channels: int,
-        num_classes: int,
-        growth_rate: int,
-        block_config: Tuple[int, ...],
-        bn_size: int,
-        dropout_rate: float,
-        classifier: bool,
-        activation: str,
-    ) -> nn.Sequential:
-        activation_fn = activation_function(activation)
-        # Stem: a 3x3 convolution that lifts the input to base_channels.
-        densenet = [
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=base_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(base_channels),
-            activation_fn,
-        ]
-
-        num_features = base_channels
-
-        for i, num_layers in enumerate(block_config):
-            densenet.append(
-                _DenseBlock(
-                    num_layers=num_layers,
-                    in_channels=num_features,
-                    bn_size=bn_size,
-                    growth_rate=growth_rate,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-            # Each dense block adds num_layers * growth_rate channels.
-            num_features = num_features + num_layers * growth_rate
-            # Every block except the last is followed by a transition
-            # layer that halves the number of channels.
-            if i != len(block_config) - 1:
-                densenet.append(
-                    _Transition(
-                        in_channels=num_features,
-                        out_channels=num_features // 2,
-                        activation=activation,
-                    )
-                )
-                num_features = num_features // 2
-
-        densenet.append(activation_fn)
-
-        if classifier:
-            densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
-            densenet.append(Rearrange("b c h w -> b (c h w)"))
-            densenet.append(
-                nn.Linear(in_features=num_features, out_features=num_classes)
-            )
-
-        return nn.Sequential(*densenet)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass of DenseNet."""
-        # If leading dimensions (batch and/or channel) are missing, add them.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.densenet(x)
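
For context, a minimal usage sketch of this module as it existed before the deletion. This is hypothetical: it assumes the text_recognizer package is importable at the pre-deletion revision and that activation_function("relu") from text_recognizer.networks.util returns an nn.ReLU instance, as the import and calls above imply.

    # Sketch only: DenseNet is the class removed by this commit.
    import torch

    from text_recognizer.networks.densenet import DenseNet

    # Two dense blocks of four layers each with growth rate 12:
    # stem 64 ch -> block 64 + 4*12 = 112 ch -> transition 56 ch
    # -> block 56 + 4*12 = 104 ch -> global pool -> linear to 80 classes.
    model = DenseNet(growth_rate=12, block_config=(4, 4), num_classes=80)
    logits = model(torch.randn(2, 1, 28, 28))  # (batch, channel, height, width)
    print(logits.shape)  # torch.Size([2, 80])

Note that the forward pass also accepts 2D or 3D inputs, since missing leading dimensions are inserted before the network is applied.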