diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-01 23:53:50 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-01 23:53:50 +0200 |
commit | 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (patch) | |
tree | c89c1b1a4cc1a499900f2700ab09e8535e2cfe99 /text_recognizer/networks/encoders | |
parent | 7ae1f8f9654dcea0a9a22310ac0665a5d3202f0f (diff) |
Working on new attention module
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r-- | text_recognizer/networks/encoders/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet.py | 145 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/residual_network.py | 310 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/wide_resnet.py | 221 |
4 files changed, 678 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py new file mode 100644 index 0000000..25aed0e --- /dev/null +++ b/text_recognizer/networks/encoders/__init__.py @@ -0,0 +1,2 @@ +"""Vision backbones.""" +from .efficientnet import EfficientNet diff --git a/text_recognizer/networks/encoders/efficientnet.py b/text_recognizer/networks/encoders/efficientnet.py new file mode 100644 index 0000000..61dea77 --- /dev/null +++ b/text_recognizer/networks/encoders/efficientnet.py @@ -0,0 +1,145 @@ +"""Efficient net b0 implementation.""" +import torch +from torch import nn +from torch import Tensor + + +class ConvNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + groups: int = 1, + ) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + nn.SiLU(inplace=True), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_channels: int, reduce_dim: int) -> None: + super().__init__() + self.se = nn.Sequential( + nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1] + nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1), + nn.SiLU(), + nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, x: Tensor) -> Tensor: + return x * self.se(x) + + +class InvertedResidulaBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + expand_ratio: float, + reduction: int = 4, + survival_prob: float = 0.8, + ) -> None: + super().__init__() + self.survival_prob = survival_prob + self.use_residual = in_channels == out_channels and stride == 1 + hidden_dim = in_channels * expand_ratio + self.expand = in_channels != hidden_dim + reduce_dim = in_channels // reduction + + if self.expand: + self.expand_conv = ConvNorm( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1 + ) + + self.conv = nn.Sequential( + ConvNorm( + hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim + ), + SqueezeExcite(hidden_dim, reduce_dim), + nn.Conv2d( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + ) + + def stochastic_depth(self, x: Tensor) -> Tensor: + if not self.training: + return x + + binary_tensor = ( + torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob + ) + return torch.div(x, self.survival_prob) * binary_tensor + + def forward(self, x: Tensor) -> Tensor: + out = self.expand_conv(x) if self.expand else x + if self.use_residual: + return self.stochastic_depth(self.conv(out)) + x + return self.conv(out) + + +class EfficientNet(nn.Module): + """Efficient net b0 backbone.""" + + def __init__(self) -> None: + super().__init__() + self.base_model = [ + # expand_ratio, channels, repeats, stride, kernel_size + [1, 16, 1, 1, 3], + [6, 24, 2, 2, 3], + [6, 40, 2, 2, 5], + [6, 80, 3, 2, 3], + [6, 112, 3, 1, 5], + [6, 192, 4, 2, 5], + [6, 320, 1, 1, 3], + ] + + self.backbone = self._build_b0() + + def _build_b0(self) -> nn.Sequential: + in_channels = 32 + layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)] + + for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model: + for i in range(repeats): + layers.append( + InvertedResidulaBlock( + in_channels, + out_channels, + expand_ratio=expand_ratio, + stride=stride if i == 0 else 1, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + in_channels = out_channels + layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0)) + + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) diff --git a/text_recognizer/networks/encoders/residual_network.py b/text_recognizer/networks/encoders/residual_network.py new file mode 100644 index 0000000..c33f419 --- /dev/null +++ b/text_recognizer/networks/encoders/residual_network.py @@ -0,0 +1,310 @@ +"""Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +class Conv2dAuto(nn.Conv2d): + """Convolution with auto padding based on kernel size.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) + + +def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: + """3x3 convolution with batch norm.""" + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + return nn.Sequential( + conv3x3(in_channels, out_channels, *args, **kwargs), + nn.BatchNorm2d(out_channels), + ) + + +class IdentityBlock(nn.Module): + """Residual with identity block.""" + + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu" + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.blocks = nn.Identity() + self.activation_fn = activation_function(activation) + self.shortcut = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self.apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + x = self.activation_fn(x) + return x + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.out_channels + + +class ResidualBlock(IdentityBlock): + """Residual with nonlinear shortcut.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: int = 1, + downsampling: int = 1, + *args, + **kwargs + ) -> None: + """Short summary. + + Args: + in_channels (int): Number of in channels. + out_channels (int): umber of out channels. + expansion (int): Expansion factor of the out channels. Defaults to 1. + downsampling (int): Downsampling factor used in stride. Defaults to 1. + *args (type): Extra arguments. + **kwargs (type): Extra key value arguments. + + """ + super().__init__(in_channels, out_channels, *args, **kwargs) + self.expansion = expansion + self.downsampling = downsampling + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.expanded_channels, + kernel_size=1, + stride=self.downsampling, + bias=False, + ), + nn.BatchNorm2d(self.expanded_channels), + ) + if self.apply_shortcut + else None + ) + + @property + def expanded_channels(self) -> int: + """Computes the expanded output channels.""" + return self.out_channels * self.expansion + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.expanded_channels + + +class BasicBlock(ResidualBlock): + """Basic ResNet block.""" + + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + bias=False, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + bias=False, + ), + ) + + +class BottleNeckBlock(ResidualBlock): + """Bottleneck block to increase depth while minimizing parameter size.""" + + expansion = 4 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + kernel_size=1, + ), + ) + + +class ResidualLayer(nn.Module): + """ResNet layer.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + block: BasicBlock = BasicBlock, + num_blocks: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + downsampling = 2 if in_channels != out_channels else 1 + self.blocks = nn.Sequential( + block( + in_channels, out_channels, *args, **kwargs, downsampling=downsampling + ), + *[ + block( + out_channels * block.expansion, + out_channels, + downsampling=1, + *args, + **kwargs + ) + for _ in range(num_blocks - 1) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.blocks(x) + return x + + +class ResidualNetworkEncoder(nn.Module): + """Encoder network.""" + + def __init__( + self, + in_channels: int = 1, + block_sizes: Union[int, List[int]] = (32, 64), + depths: Union[int, List[int]] = (2, 2), + activation: str = "relu", + block: Type[nn.Module] = BasicBlock, + levels: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + self.block_sizes = ( + block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels + ) + self.depths = depths if isinstance(depths, list) else [depths] * levels + self.activation = activation + self.gate = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=self.block_sizes[0], + kernel_size=7, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(self.block_sizes[0]), + activation_function(self.activation), + # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + ) + + self.blocks = self._configure_blocks(block) + + def _configure_blocks( + self, block: Type[nn.Module], *args, **kwargs + ) -> nn.Sequential: + channels = [self.block_sizes[0]] + list( + zip(self.block_sizes, self.block_sizes[1:]) + ) + blocks = [ + ResidualLayer( + in_channels=channels[0], + out_channels=channels[0], + num_blocks=self.depths[0], + block=block, + activation=self.activation, + *args, + **kwargs + ) + ] + blocks += [ + ResidualLayer( + in_channels=in_channels * block.expansion, + out_channels=out_channels, + num_blocks=num_blocks, + block=block, + activation=self.activation, + *args, + **kwargs + ) + for (in_channels, out_channels), num_blocks in zip( + channels[1:], self.depths[1:] + ) + ] + + return nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.gate(x) + x = self.blocks(x) + return x + + +class ResidualNetworkDecoder(nn.Module): + """Classification head.""" + + def __init__(self, in_features: int, num_classes: int = 80) -> None: + super().__init__() + self.decoder = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(in_features=in_features, out_features=num_classes), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.decoder(x) + + +class ResidualNetwork(nn.Module): + """Full residual network.""" + + def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: + super().__init__() + self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) + self.decoder = ResidualNetworkDecoder( + in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, + num_classes=num_classes, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.encoder(x) + x = self.decoder(x) + return x diff --git a/text_recognizer/networks/encoders/wide_resnet.py b/text_recognizer/networks/encoders/wide_resnet.py new file mode 100644 index 0000000..b767778 --- /dev/null +++ b/text_recognizer/networks/encoders/wide_resnet.py @@ -0,0 +1,221 @@ +"""Wide Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Reduce +import numpy as np +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """Helper function for a 3x3 2d convolution.""" + return nn.Conv2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + + +def conv_init(module: Type[nn.Module]) -> None: + """Initializes the weights for convolution and batchnorms.""" + classname = module.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) + nn.init.constant_(module.bias, 0) + elif classname.find("BatchNorm") != -1: + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +class WideBlock(nn.Module): + """Block used in WideResNet.""" + + def __init__( + self, + in_planes: int, + out_planes: int, + dropout_rate: float, + stride: int = 1, + activation: str = "relu", + ) -> None: + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.dropout_rate = dropout_rate + self.stride = stride + self.activation = activation_function(activation) + + # Build blocks. + self.blocks = nn.Sequential( + nn.BatchNorm2d(self.in_planes), + self.activation, + conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), + nn.Dropout(p=self.dropout_rate), + nn.BatchNorm2d(self.out_planes), + self.activation, + conv3x3( + in_planes=self.out_planes, + out_planes=self.out_planes, + stride=self.stride, + ), + ) + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_planes, + out_channels=self.out_planes, + kernel_size=1, + stride=self.stride, + bias=False, + ), + ) + if self._apply_shortcut + else None + ) + + @property + def _apply_shortcut(self) -> bool: + """If shortcut should be applied or not.""" + return self.stride != 1 or self.in_planes != self.out_planes + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self._apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + return x + + +class WideResidualNetwork(nn.Module): + """WideResNet for character predictions. + + Can be used for classification or encoding of images to a latent vector. + + """ + + def __init__( + self, + in_channels: int = 1, + in_planes: int = 16, + num_classes: int = 80, + depth: int = 16, + width_factor: int = 10, + dropout_rate: float = 0.0, + num_layers: int = 3, + block: Type[nn.Module] = WideBlock, + num_stages: Optional[List[int]] = None, + activation: str = "relu", + use_decoder: bool = True, + ) -> None: + """The initialization of the WideResNet. + + Args: + in_channels (int): Number of input channels. Defaults to 1. + in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. + num_classes (int): Number of classes. Defaults to 80. + depth (int): Set the number of blocks to use. Defaults to 16. + width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. + dropout_rate (float): The dropout rate. Defaults to 0.0. + num_layers (int): Number of layers of blocks. Defaults to 3. + block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. + num_stages (List[int]): If given, will use these channel values. Defaults to None. + activation (str): Name of the activation to use. Defaults to "relu". + use_decoder (bool): If True, the network output character predictions, if False, the network outputs a + latent vector. Defaults to True. + + Raises: + RuntimeError: If the depth is not of the size `6n+4`. + + """ + + super().__init__() + if (depth - 4) % 6 != 0: + raise RuntimeError("Wide-resnet depth should be 6n+4") + self.in_channels = in_channels + self.in_planes = in_planes + self.num_classes = num_classes + self.num_blocks = (depth - 4) // 6 + self.width_factor = width_factor + self.num_layers = num_layers + self.block = block + self.dropout_rate = dropout_rate + self.activation = activation_function(activation) + + if num_stages is None: + self.num_stages = [self.in_planes] + [ + self.in_planes * 2 ** n * self.width_factor + for n in range(self.num_layers) + ] + else: + self.num_stages = [self.in_planes] + num_stages + + self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) + self.strides = [1] + [2] * (self.num_layers - 1) + + self.encoder = nn.Sequential( + conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), + *[ + self._configure_wide_layer( + in_planes=in_planes, + out_planes=out_planes, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip( + self.num_stages, self.strides + ) + ], + ) + + self.decoder = ( + nn.Sequential( + nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), + self.activation, + Reduce("b c h w -> b c", "mean"), + nn.Linear( + in_features=self.num_stages[-1][-1], out_features=self.num_classes + ), + ) + if use_decoder + else None + ) + + # self.apply(conv_init) + + def _configure_wide_layer( + self, in_planes: int, out_planes: int, stride: int, activation: str + ) -> List: + strides = [stride] + [1] * (self.num_blocks - 1) + planes = [out_planes] * len(strides) + planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) + return nn.Sequential( + *[ + self.block( + in_planes=in_planes, + out_planes=out_planes, + dropout_rate=self.dropout_rate, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip(planes, strides) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass.""" + if len(x.shape) < 4: + x = x[(None,) * int(4 - len(x.shape))] + x = self.encoder(x) + if self.decoder is not None: + x = self.decoder(x) + return x |