Diffstat (limited to 'src/text_recognizer/networks/densenet.py')
-rw-r--r-- | src/text_recognizer/networks/densenet.py | 225
1 file changed, 0 insertions(+), 225 deletions(-)
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
deleted file mode 100644
index 7dc58d9..0000000
--- a/src/text_recognizer/networks/densenet.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""Defines a Densely Connected Convolutional Network (DenseNet) in PyTorch.
-
-Sources:
-https://arxiv.org/abs/1608.06993
-https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
-
-"""
-from typing import List, Tuple, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DenseLayer(nn.Module):
-    """A dense layer with two blocks of pre-batch norm -> activation function -> conv layer."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        growth_rate: int,
-        bn_size: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        activation_fn = activation_function(activation)
-        dense_layer = [
-            nn.BatchNorm2d(in_channels),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=bn_size * growth_rate,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(bn_size * growth_rate),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=bn_size * growth_rate,
-                out_channels=growth_rate,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False,
-            ),
-        ]
-        if dropout_rate:
-            dense_layer.append(nn.Dropout(p=dropout_rate))
-
-        self.dense_layer = nn.Sequential(*dense_layer)
-
-    def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
-        # Feature maps from all preceding layers arrive as a list and are
-        # concatenated along the channel dimension.
-        if isinstance(x, list):
-            x = torch.cat(x, 1)
-        return self.dense_layer(x)
-
-
-class _DenseBlock(nn.Module):
-    def __init__(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.dense_block = self._build_dense_blocks(
-            num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
-        )
-
-    def _build_dense_blocks(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> nn.ModuleList:
-        dense_block = []
-        for i in range(num_layers):
-            # Layer i sees the block input plus the outputs of the i previous layers.
-            dense_block.append(
-                _DenseLayer(
-                    in_channels=in_channels + i * growth_rate,
-                    growth_rate=growth_rate,
-                    bn_size=bn_size,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-        return nn.ModuleList(dense_block)
-
-    def forward(self, x: Tensor) -> Tensor:
-        feature_maps = [x]
-        for layer in self.dense_block:
-            x = layer(feature_maps)
-            feature_maps.append(x)
-        return torch.cat(feature_maps, 1)
-
-
-class _Transition(nn.Module):
-    def __init__(
-        self, in_channels: int, out_channels: int, activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        activation_fn = activation_function(activation)
-        self.transition = nn.Sequential(
-            nn.BatchNorm2d(in_channels),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-            nn.AvgPool2d(kernel_size=2, stride=2),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.transition(x)
-
-
-class DenseNet(nn.Module):
-    """Implementation of DenseNet, a network architecture that concatenates previous feature maps for maximum information flow."""
-
-    def __init__(
-        self,
-        growth_rate: int = 32,
-        block_config: Tuple[int, ...] = (6, 12, 24, 16),
-        in_channels: int = 1,
-        base_channels: int = 64,
-        num_classes: int = 80,
-        bn_size: int = 4,
-        dropout_rate: float = 0,
-        classifier: bool = True,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.densenet = self._configure_densenet(
-            in_channels,
-            base_channels,
-            num_classes,
-            growth_rate,
-            block_config,
-            bn_size,
-            dropout_rate,
-            classifier,
-            activation,
-        )
-
-    def _configure_densenet(
-        self,
-        in_channels: int,
-        base_channels: int,
-        num_classes: int,
-        growth_rate: int,
-        block_config: Tuple[int, ...],
-        bn_size: int,
-        dropout_rate: float,
-        classifier: bool,
-        activation: str,
-    ) -> nn.Sequential:
-        activation_fn = activation_function(activation)
-        densenet = [
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=base_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(base_channels),
-            activation_fn,
-        ]
-
-        num_features = base_channels
-
-        for i, num_layers in enumerate(block_config):
-            densenet.append(
-                _DenseBlock(
-                    num_layers=num_layers,
-                    in_channels=num_features,
-                    bn_size=bn_size,
-                    growth_rate=growth_rate,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-            num_features = num_features + num_layers * growth_rate
-            # Between blocks, a transition layer halves both the channel count
-            # and the spatial resolution.
-            if i != len(block_config) - 1:
-                densenet.append(
-                    _Transition(
-                        in_channels=num_features,
-                        out_channels=num_features // 2,
-                        activation=activation,
-                    )
-                )
-                num_features = num_features // 2
-
-        densenet.append(activation_fn)
-
-        if classifier:
-            densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
-            densenet.append(Rearrange("b c h w -> b (c h w)"))
-            densenet.append(
-                nn.Linear(in_features=num_features, out_features=num_classes)
-            )
-
-        return nn.Sequential(*densenet)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass of DenseNet."""
-        # Prepend any missing leading dimensions, e.g. a missing batch dimension.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.densenet(x)