From 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 1 May 2021 23:53:50 +0200 Subject: Working on new attention module --- text_recognizer/networks/__init__.py | 2 +- text_recognizer/networks/backbones/__init__.py | 2 - text_recognizer/networks/backbones/efficientnet.py | 145 ---------- text_recognizer/networks/encoders/__init__.py | 2 + text_recognizer/networks/encoders/efficientnet.py | 145 ++++++++++ .../networks/encoders/residual_network.py | 310 +++++++++++++++++++++ text_recognizer/networks/encoders/wide_resnet.py | 221 +++++++++++++++ text_recognizer/networks/residual_network.py | 310 --------------------- text_recognizer/networks/transformer/attention.py | 119 ++++---- text_recognizer/networks/transformer/norm.py | 13 + text_recognizer/networks/wide_resnet.py | 221 --------------- 11 files changed, 741 insertions(+), 749 deletions(-) delete mode 100644 text_recognizer/networks/backbones/__init__.py delete mode 100644 text_recognizer/networks/backbones/efficientnet.py create mode 100644 text_recognizer/networks/encoders/__init__.py create mode 100644 text_recognizer/networks/encoders/efficientnet.py create mode 100644 text_recognizer/networks/encoders/residual_network.py create mode 100644 text_recognizer/networks/encoders/wide_resnet.py delete mode 100644 text_recognizer/networks/residual_network.py delete mode 100644 text_recognizer/networks/wide_resnet.py (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 63b43b2..a9117f8 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,4 +1,4 @@ """Network modules""" -from .backbones import EfficientNet +from .encoders import EfficientNet from .vqvae import VQVAE from .cnn_transformer import CNNTransformer diff --git a/text_recognizer/networks/backbones/__init__.py b/text_recognizer/networks/backbones/__init__.py deleted file mode 100644 index 25aed0e..0000000 --- a/text_recognizer/networks/backbones/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Vision backbones.""" -from .efficientnet import EfficientNet diff --git a/text_recognizer/networks/backbones/efficientnet.py b/text_recognizer/networks/backbones/efficientnet.py deleted file mode 100644 index 61dea77..0000000 --- a/text_recognizer/networks/backbones/efficientnet.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Efficient net b0 implementation.""" -import torch -from torch import nn -from torch import Tensor - - -class ConvNorm(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int, - padding: int, - groups: int = 1, - ) -> None: - super().__init__() - self.block = nn.Sequential( - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - groups=groups, - bias=False, - ), - nn.BatchNorm2d(num_features=out_channels), - nn.SiLU(inplace=True), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.block(x) - - -class SqueezeExcite(nn.Module): - def __init__(self, in_channels: int, reduce_dim: int) -> None: - super().__init__() - self.se = nn.Sequential( - nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1] - nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1), - nn.SiLU(), - nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1), - nn.Sigmoid(), - ) - - def forward(self, x: Tensor) -> Tensor: - return x * self.se(x) - - -class InvertedResidulaBlock(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int, - padding: int, - expand_ratio: float, - reduction: int = 4, - survival_prob: float = 0.8, - ) -> None: - super().__init__() - self.survival_prob = survival_prob - self.use_residual = in_channels == out_channels and stride == 1 - hidden_dim = in_channels * expand_ratio - self.expand = in_channels != hidden_dim - reduce_dim = in_channels // reduction - - if self.expand: - self.expand_conv = ConvNorm( - in_channels, hidden_dim, kernel_size=3, stride=1, padding=1 - ) - - self.conv = nn.Sequential( - ConvNorm( - hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim - ), - SqueezeExcite(hidden_dim, reduce_dim), - nn.Conv2d( - in_channels=hidden_dim, - out_channels=out_channels, - kernel_size=1, - bias=False, - ), - nn.BatchNorm2d(num_features=out_channels), - ) - - def stochastic_depth(self, x: Tensor) -> Tensor: - if not self.training: - return x - - binary_tensor = ( - torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob - ) - return torch.div(x, self.survival_prob) * binary_tensor - - def forward(self, x: Tensor) -> Tensor: - out = self.expand_conv(x) if self.expand else x - if self.use_residual: - return self.stochastic_depth(self.conv(out)) + x - return self.conv(out) - - -class EfficientNet(nn.Module): - """Efficient net b0 backbone.""" - - def __init__(self) -> None: - super().__init__() - self.base_model = [ - # expand_ratio, channels, repeats, stride, kernel_size - [1, 16, 1, 1, 3], - [6, 24, 2, 2, 3], - [6, 40, 2, 2, 5], - [6, 80, 3, 2, 3], - [6, 112, 3, 1, 5], - [6, 192, 4, 2, 5], - [6, 320, 1, 1, 3], - ] - - self.backbone = self._build_b0() - - def _build_b0(self) -> nn.Sequential: - in_channels = 32 - layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)] - - for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model: - for i in range(repeats): - layers.append( - InvertedResidulaBlock( - in_channels, - out_channels, - expand_ratio=expand_ratio, - stride=stride if i == 0 else 1, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - ) - in_channels = out_channels - layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0)) - - return nn.Sequential(*layers) - - def forward(self, x: Tensor) -> Tensor: - return self.backbone(x) diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py new file mode 100644 index 0000000..25aed0e --- /dev/null +++ b/text_recognizer/networks/encoders/__init__.py @@ -0,0 +1,2 @@ +"""Vision backbones.""" +from .efficientnet import EfficientNet diff --git a/text_recognizer/networks/encoders/efficientnet.py b/text_recognizer/networks/encoders/efficientnet.py new file mode 100644 index 0000000..61dea77 --- /dev/null +++ b/text_recognizer/networks/encoders/efficientnet.py @@ -0,0 +1,145 @@ +"""Efficient net b0 implementation.""" +import torch +from torch import nn +from torch import Tensor + + +class ConvNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + groups: int = 1, + ) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + nn.SiLU(inplace=True), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_channels: int, reduce_dim: int) -> None: + super().__init__() + self.se = nn.Sequential( + nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1] + nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1), + nn.SiLU(), + nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, x: Tensor) -> Tensor: + return x * self.se(x) + + +class InvertedResidulaBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + expand_ratio: float, + reduction: int = 4, + survival_prob: float = 0.8, + ) -> None: + super().__init__() + self.survival_prob = survival_prob + self.use_residual = in_channels == out_channels and stride == 1 + hidden_dim = in_channels * expand_ratio + self.expand = in_channels != hidden_dim + reduce_dim = in_channels // reduction + + if self.expand: + self.expand_conv = ConvNorm( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1 + ) + + self.conv = nn.Sequential( + ConvNorm( + hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim + ), + SqueezeExcite(hidden_dim, reduce_dim), + nn.Conv2d( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + ) + + def stochastic_depth(self, x: Tensor) -> Tensor: + if not self.training: + return x + + binary_tensor = ( + torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob + ) + return torch.div(x, self.survival_prob) * binary_tensor + + def forward(self, x: Tensor) -> Tensor: + out = self.expand_conv(x) if self.expand else x + if self.use_residual: + return self.stochastic_depth(self.conv(out)) + x + return self.conv(out) + + +class EfficientNet(nn.Module): + """Efficient net b0 backbone.""" + + def __init__(self) -> None: + super().__init__() + self.base_model = [ + # expand_ratio, channels, repeats, stride, kernel_size + [1, 16, 1, 1, 3], + [6, 24, 2, 2, 3], + [6, 40, 2, 2, 5], + [6, 80, 3, 2, 3], + [6, 112, 3, 1, 5], + [6, 192, 4, 2, 5], + [6, 320, 1, 1, 3], + ] + + self.backbone = self._build_b0() + + def _build_b0(self) -> nn.Sequential: + in_channels = 32 + layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)] + + for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model: + for i in range(repeats): + layers.append( + InvertedResidulaBlock( + in_channels, + out_channels, + expand_ratio=expand_ratio, + stride=stride if i == 0 else 1, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + in_channels = out_channels + layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0)) + + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) diff --git a/text_recognizer/networks/encoders/residual_network.py b/text_recognizer/networks/encoders/residual_network.py new file mode 100644 index 0000000..c33f419 --- /dev/null +++ b/text_recognizer/networks/encoders/residual_network.py @@ -0,0 +1,310 @@ +"""Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +class Conv2dAuto(nn.Conv2d): + """Convolution with auto padding based on kernel size.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) + + +def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: + """3x3 convolution with batch norm.""" + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + return nn.Sequential( + conv3x3(in_channels, out_channels, *args, **kwargs), + nn.BatchNorm2d(out_channels), + ) + + +class IdentityBlock(nn.Module): + """Residual with identity block.""" + + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu" + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.blocks = nn.Identity() + self.activation_fn = activation_function(activation) + self.shortcut = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self.apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + x = self.activation_fn(x) + return x + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.out_channels + + +class ResidualBlock(IdentityBlock): + """Residual with nonlinear shortcut.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: int = 1, + downsampling: int = 1, + *args, + **kwargs + ) -> None: + """Short summary. + + Args: + in_channels (int): Number of in channels. + out_channels (int): umber of out channels. + expansion (int): Expansion factor of the out channels. Defaults to 1. + downsampling (int): Downsampling factor used in stride. Defaults to 1. + *args (type): Extra arguments. + **kwargs (type): Extra key value arguments. + + """ + super().__init__(in_channels, out_channels, *args, **kwargs) + self.expansion = expansion + self.downsampling = downsampling + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.expanded_channels, + kernel_size=1, + stride=self.downsampling, + bias=False, + ), + nn.BatchNorm2d(self.expanded_channels), + ) + if self.apply_shortcut + else None + ) + + @property + def expanded_channels(self) -> int: + """Computes the expanded output channels.""" + return self.out_channels * self.expansion + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.expanded_channels + + +class BasicBlock(ResidualBlock): + """Basic ResNet block.""" + + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + bias=False, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + bias=False, + ), + ) + + +class BottleNeckBlock(ResidualBlock): + """Bottleneck block to increase depth while minimizing parameter size.""" + + expansion = 4 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + kernel_size=1, + ), + ) + + +class ResidualLayer(nn.Module): + """ResNet layer.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + block: BasicBlock = BasicBlock, + num_blocks: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + downsampling = 2 if in_channels != out_channels else 1 + self.blocks = nn.Sequential( + block( + in_channels, out_channels, *args, **kwargs, downsampling=downsampling + ), + *[ + block( + out_channels * block.expansion, + out_channels, + downsampling=1, + *args, + **kwargs + ) + for _ in range(num_blocks - 1) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.blocks(x) + return x + + +class ResidualNetworkEncoder(nn.Module): + """Encoder network.""" + + def __init__( + self, + in_channels: int = 1, + block_sizes: Union[int, List[int]] = (32, 64), + depths: Union[int, List[int]] = (2, 2), + activation: str = "relu", + block: Type[nn.Module] = BasicBlock, + levels: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + self.block_sizes = ( + block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels + ) + self.depths = depths if isinstance(depths, list) else [depths] * levels + self.activation = activation + self.gate = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=self.block_sizes[0], + kernel_size=7, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(self.block_sizes[0]), + activation_function(self.activation), + # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + ) + + self.blocks = self._configure_blocks(block) + + def _configure_blocks( + self, block: Type[nn.Module], *args, **kwargs + ) -> nn.Sequential: + channels = [self.block_sizes[0]] + list( + zip(self.block_sizes, self.block_sizes[1:]) + ) + blocks = [ + ResidualLayer( + in_channels=channels[0], + out_channels=channels[0], + num_blocks=self.depths[0], + block=block, + activation=self.activation, + *args, + **kwargs + ) + ] + blocks += [ + ResidualLayer( + in_channels=in_channels * block.expansion, + out_channels=out_channels, + num_blocks=num_blocks, + block=block, + activation=self.activation, + *args, + **kwargs + ) + for (in_channels, out_channels), num_blocks in zip( + channels[1:], self.depths[1:] + ) + ] + + return nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.gate(x) + x = self.blocks(x) + return x + + +class ResidualNetworkDecoder(nn.Module): + """Classification head.""" + + def __init__(self, in_features: int, num_classes: int = 80) -> None: + super().__init__() + self.decoder = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(in_features=in_features, out_features=num_classes), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.decoder(x) + + +class ResidualNetwork(nn.Module): + """Full residual network.""" + + def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: + super().__init__() + self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) + self.decoder = ResidualNetworkDecoder( + in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, + num_classes=num_classes, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.encoder(x) + x = self.decoder(x) + return x diff --git a/text_recognizer/networks/encoders/wide_resnet.py b/text_recognizer/networks/encoders/wide_resnet.py new file mode 100644 index 0000000..b767778 --- /dev/null +++ b/text_recognizer/networks/encoders/wide_resnet.py @@ -0,0 +1,221 @@ +"""Wide Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Reduce +import numpy as np +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """Helper function for a 3x3 2d convolution.""" + return nn.Conv2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + + +def conv_init(module: Type[nn.Module]) -> None: + """Initializes the weights for convolution and batchnorms.""" + classname = module.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) + nn.init.constant_(module.bias, 0) + elif classname.find("BatchNorm") != -1: + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +class WideBlock(nn.Module): + """Block used in WideResNet.""" + + def __init__( + self, + in_planes: int, + out_planes: int, + dropout_rate: float, + stride: int = 1, + activation: str = "relu", + ) -> None: + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.dropout_rate = dropout_rate + self.stride = stride + self.activation = activation_function(activation) + + # Build blocks. + self.blocks = nn.Sequential( + nn.BatchNorm2d(self.in_planes), + self.activation, + conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), + nn.Dropout(p=self.dropout_rate), + nn.BatchNorm2d(self.out_planes), + self.activation, + conv3x3( + in_planes=self.out_planes, + out_planes=self.out_planes, + stride=self.stride, + ), + ) + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_planes, + out_channels=self.out_planes, + kernel_size=1, + stride=self.stride, + bias=False, + ), + ) + if self._apply_shortcut + else None + ) + + @property + def _apply_shortcut(self) -> bool: + """If shortcut should be applied or not.""" + return self.stride != 1 or self.in_planes != self.out_planes + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self._apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + return x + + +class WideResidualNetwork(nn.Module): + """WideResNet for character predictions. + + Can be used for classification or encoding of images to a latent vector. + + """ + + def __init__( + self, + in_channels: int = 1, + in_planes: int = 16, + num_classes: int = 80, + depth: int = 16, + width_factor: int = 10, + dropout_rate: float = 0.0, + num_layers: int = 3, + block: Type[nn.Module] = WideBlock, + num_stages: Optional[List[int]] = None, + activation: str = "relu", + use_decoder: bool = True, + ) -> None: + """The initialization of the WideResNet. + + Args: + in_channels (int): Number of input channels. Defaults to 1. + in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. + num_classes (int): Number of classes. Defaults to 80. + depth (int): Set the number of blocks to use. Defaults to 16. + width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. + dropout_rate (float): The dropout rate. Defaults to 0.0. + num_layers (int): Number of layers of blocks. Defaults to 3. + block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. + num_stages (List[int]): If given, will use these channel values. Defaults to None. + activation (str): Name of the activation to use. Defaults to "relu". + use_decoder (bool): If True, the network output character predictions, if False, the network outputs a + latent vector. Defaults to True. + + Raises: + RuntimeError: If the depth is not of the size `6n+4`. + + """ + + super().__init__() + if (depth - 4) % 6 != 0: + raise RuntimeError("Wide-resnet depth should be 6n+4") + self.in_channels = in_channels + self.in_planes = in_planes + self.num_classes = num_classes + self.num_blocks = (depth - 4) // 6 + self.width_factor = width_factor + self.num_layers = num_layers + self.block = block + self.dropout_rate = dropout_rate + self.activation = activation_function(activation) + + if num_stages is None: + self.num_stages = [self.in_planes] + [ + self.in_planes * 2 ** n * self.width_factor + for n in range(self.num_layers) + ] + else: + self.num_stages = [self.in_planes] + num_stages + + self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) + self.strides = [1] + [2] * (self.num_layers - 1) + + self.encoder = nn.Sequential( + conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), + *[ + self._configure_wide_layer( + in_planes=in_planes, + out_planes=out_planes, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip( + self.num_stages, self.strides + ) + ], + ) + + self.decoder = ( + nn.Sequential( + nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), + self.activation, + Reduce("b c h w -> b c", "mean"), + nn.Linear( + in_features=self.num_stages[-1][-1], out_features=self.num_classes + ), + ) + if use_decoder + else None + ) + + # self.apply(conv_init) + + def _configure_wide_layer( + self, in_planes: int, out_planes: int, stride: int, activation: str + ) -> List: + strides = [stride] + [1] * (self.num_blocks - 1) + planes = [out_planes] * len(strides) + planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) + return nn.Sequential( + *[ + self.block( + in_planes=in_planes, + out_planes=out_planes, + dropout_rate=self.dropout_rate, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip(planes, strides) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass.""" + if len(x.shape) < 4: + x = x[(None,) * int(4 - len(x.shape))] + x = self.encoder(x) + if self.decoder is not None: + x = self.decoder(x) + return x diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py deleted file mode 100644 index c33f419..0000000 --- a/text_recognizer/networks/residual_network.py +++ /dev/null @@ -1,310 +0,0 @@ -"""Residual CNN.""" -from functools import partial -from typing import Callable, Dict, List, Optional, Type, Union - -from einops.layers.torch import Rearrange, Reduce -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class Conv2dAuto(nn.Conv2d): - """Convolution with auto padding based on kernel size.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) - - -def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: - """3x3 convolution with batch norm.""" - conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) - return nn.Sequential( - conv3x3(in_channels, out_channels, *args, **kwargs), - nn.BatchNorm2d(out_channels), - ) - - -class IdentityBlock(nn.Module): - """Residual with identity block.""" - - def __init__( - self, in_channels: int, out_channels: int, activation: str = "relu" - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.blocks = nn.Identity() - self.activation_fn = activation_function(activation) - self.shortcut = nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - residual = x - if self.apply_shortcut: - residual = self.shortcut(x) - x = self.blocks(x) - x += residual - x = self.activation_fn(x) - return x - - @property - def apply_shortcut(self) -> bool: - """Check if shortcut should be applied.""" - return self.in_channels != self.out_channels - - -class ResidualBlock(IdentityBlock): - """Residual with nonlinear shortcut.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - expansion: int = 1, - downsampling: int = 1, - *args, - **kwargs - ) -> None: - """Short summary. - - Args: - in_channels (int): Number of in channels. - out_channels (int): umber of out channels. - expansion (int): Expansion factor of the out channels. Defaults to 1. - downsampling (int): Downsampling factor used in stride. Defaults to 1. - *args (type): Extra arguments. - **kwargs (type): Extra key value arguments. - - """ - super().__init__(in_channels, out_channels, *args, **kwargs) - self.expansion = expansion - self.downsampling = downsampling - - self.shortcut = ( - nn.Sequential( - nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.expanded_channels, - kernel_size=1, - stride=self.downsampling, - bias=False, - ), - nn.BatchNorm2d(self.expanded_channels), - ) - if self.apply_shortcut - else None - ) - - @property - def expanded_channels(self) -> int: - """Computes the expanded output channels.""" - return self.out_channels * self.expansion - - @property - def apply_shortcut(self) -> bool: - """Check if shortcut should be applied.""" - return self.in_channels != self.expanded_channels - - -class BasicBlock(ResidualBlock): - """Basic ResNet block.""" - - expansion = 1 - - def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: - super().__init__(in_channels, out_channels, *args, **kwargs) - self.blocks = nn.Sequential( - conv_bn( - in_channels=self.in_channels, - out_channels=self.out_channels, - bias=False, - stride=self.downsampling, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.expanded_channels, - bias=False, - ), - ) - - -class BottleNeckBlock(ResidualBlock): - """Bottleneck block to increase depth while minimizing parameter size.""" - - expansion = 4 - - def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: - super().__init__(in_channels, out_channels, *args, **kwargs) - self.blocks = nn.Sequential( - conv_bn( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=1, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.downsampling, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.expanded_channels, - kernel_size=1, - ), - ) - - -class ResidualLayer(nn.Module): - """ResNet layer.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - block: BasicBlock = BasicBlock, - num_blocks: int = 1, - *args, - **kwargs - ) -> None: - super().__init__() - downsampling = 2 if in_channels != out_channels else 1 - self.blocks = nn.Sequential( - block( - in_channels, out_channels, *args, **kwargs, downsampling=downsampling - ), - *[ - block( - out_channels * block.expansion, - out_channels, - downsampling=1, - *args, - **kwargs - ) - for _ in range(num_blocks - 1) - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - x = self.blocks(x) - return x - - -class ResidualNetworkEncoder(nn.Module): - """Encoder network.""" - - def __init__( - self, - in_channels: int = 1, - block_sizes: Union[int, List[int]] = (32, 64), - depths: Union[int, List[int]] = (2, 2), - activation: str = "relu", - block: Type[nn.Module] = BasicBlock, - levels: int = 1, - *args, - **kwargs - ) -> None: - super().__init__() - self.block_sizes = ( - block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels - ) - self.depths = depths if isinstance(depths, list) else [depths] * levels - self.activation = activation - self.gate = nn.Sequential( - nn.Conv2d( - in_channels=in_channels, - out_channels=self.block_sizes[0], - kernel_size=7, - stride=2, - padding=1, - bias=False, - ), - nn.BatchNorm2d(self.block_sizes[0]), - activation_function(self.activation), - # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), - ) - - self.blocks = self._configure_blocks(block) - - def _configure_blocks( - self, block: Type[nn.Module], *args, **kwargs - ) -> nn.Sequential: - channels = [self.block_sizes[0]] + list( - zip(self.block_sizes, self.block_sizes[1:]) - ) - blocks = [ - ResidualLayer( - in_channels=channels[0], - out_channels=channels[0], - num_blocks=self.depths[0], - block=block, - activation=self.activation, - *args, - **kwargs - ) - ] - blocks += [ - ResidualLayer( - in_channels=in_channels * block.expansion, - out_channels=out_channels, - num_blocks=num_blocks, - block=block, - activation=self.activation, - *args, - **kwargs - ) - for (in_channels, out_channels), num_blocks in zip( - channels[1:], self.depths[1:] - ) - ] - - return nn.Sequential(*blocks) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) == 3: - x = x.unsqueeze(0) - x = self.gate(x) - x = self.blocks(x) - return x - - -class ResidualNetworkDecoder(nn.Module): - """Classification head.""" - - def __init__(self, in_features: int, num_classes: int = 80) -> None: - super().__init__() - self.decoder = nn.Sequential( - Reduce("b c h w -> b c", "mean"), - nn.Linear(in_features=in_features, out_features=num_classes), - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - return self.decoder(x) - - -class ResidualNetwork(nn.Module): - """Full residual network.""" - - def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: - super().__init__() - self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) - self.decoder = ResidualNetworkDecoder( - in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, - num_classes=num_classes, - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - x = self.encoder(x) - x = self.decoder(x) - return x diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index ac75d2f..e1324af 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,94 +1,73 @@ """Implementes the attention module for the transformer.""" from typing import Optional, Tuple -from einops import rearrange +from einops.layers.torch import Rearrange import numpy as np import torch from torch import nn from torch import Tensor +import torch.nn.functional as F +from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb -class MultiHeadAttention(nn.Module): - """Implementation of multihead attention.""" +class Attention(nn.Module): def __init__( - self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 + self, + dim: int, + num_heads: int, + dim_head: int = 64, + dropout_rate: float = 0.0, + causal: bool = False, ) -> None: - super().__init__() - self.hidden_dim = hidden_dim + self.scale = dim ** -0.5 self.num_heads = num_heads - self.fc_q = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_k = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_v = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) - - self._init_weights() + self.causal = causal + inner_dim = dim * dim_head - self.dropout = nn.Dropout(p=dropout_rate) - - def _init_weights(self) -> None: - nn.init.normal_( - self.fc_q.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.normal_( - self.fc_k.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + # Attnetion + self.qkv_fn = nn.Sequential( + nn.Linear(dim, 3 * inner_dim, bias=False), + Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads), ) - nn.init.normal_( - self.fc_v.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.xavier_normal_(self.fc_out.weight) + self.dropout = nn.Dropout(dropout_rate) + self.attn_fn = F.softmax - @staticmethod - def scaled_dot_product_attention( - query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None - ) -> Tensor: - """Calculates the scaled dot product attention.""" + # Feedforward + self.proj = nn.Linear(inner_dim, dim) - # Compute the energy. - energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( - query.shape[-1] - ) - - # If we have a mask for padding some inputs. - if mask is not None: - energy = energy.masked_fill(mask == 0, -np.inf) - - # Compute the attention from the energy. - attention = torch.softmax(energy, dim=3) + @staticmethod + def _apply_rotary_emb( + q: Tensor, k: Tensor, rotary_pos_emb: Tensor + ) -> Tuple[Tensor, Tensor]: + l = rotary_pos_emb.shape[-1] + (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k)) + ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb) + q = torch.cat((ql, qr), dim=-1) + k = torch.cat((kl, kr), dim=-1) + return q, k - out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) - out = rearrange(out, "b head l v -> b l (head v)") - return out, attention + def _cross_attention(self) -> Tensor: + pass def forward( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + self, + x: Tensor, + context: Optional[Tensor], + mask: Optional[Tensor], + context_mask: Optional[Tensor], + rotary_pos_emb: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: - """Forward pass for computing the multihead attention.""" - # Get the query, key, and value tensor. - query = rearrange( - self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads - ) - key = rearrange( - self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads - ) - value = rearrange( - self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads + q, k, v = self.qkv_fn(x) + q, k = ( + self._apply_rotary_emb(q, k, rotary_pos_emb) + if rotary_pos_emb is not None + else q, + k, ) - out, attention = self.scaled_dot_product_attention(query, key, value, mask) + if any(x is not None for x in (mask, context_mask)): + pass - out = self.fc_out(out) - out = self.dropout(out) - return out, attention + # Compute the attention + energy = (q @ k.transpose(-2, -1)) * self.scale diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 99a5291..9160876 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -20,3 +20,16 @@ class Rezero(nn.Module): def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: x, *rest = self.fn(x, **kwargs) return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1.0e-5) -> None: + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x: Tensor) -> Tensor: + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) self.g + diff --git a/text_recognizer/networks/wide_resnet.py b/text_recognizer/networks/wide_resnet.py deleted file mode 100644 index b767778..0000000 --- a/text_recognizer/networks/wide_resnet.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Wide Residual CNN.""" -from functools import partial -from typing import Callable, Dict, List, Optional, Type, Union - -from einops.layers.torch import Reduce -import numpy as np -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """Helper function for a 3x3 2d convolution.""" - return nn.Conv2d( - in_channels=in_planes, - out_channels=out_planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False, - ) - - -def conv_init(module: Type[nn.Module]) -> None: - """Initializes the weights for convolution and batchnorms.""" - classname = module.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) - nn.init.constant_(module.bias, 0) - elif classname.find("BatchNorm") != -1: - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) - - -class WideBlock(nn.Module): - """Block used in WideResNet.""" - - def __init__( - self, - in_planes: int, - out_planes: int, - dropout_rate: float, - stride: int = 1, - activation: str = "relu", - ) -> None: - super().__init__() - self.in_planes = in_planes - self.out_planes = out_planes - self.dropout_rate = dropout_rate - self.stride = stride - self.activation = activation_function(activation) - - # Build blocks. - self.blocks = nn.Sequential( - nn.BatchNorm2d(self.in_planes), - self.activation, - conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), - nn.Dropout(p=self.dropout_rate), - nn.BatchNorm2d(self.out_planes), - self.activation, - conv3x3( - in_planes=self.out_planes, - out_planes=self.out_planes, - stride=self.stride, - ), - ) - - self.shortcut = ( - nn.Sequential( - nn.Conv2d( - in_channels=self.in_planes, - out_channels=self.out_planes, - kernel_size=1, - stride=self.stride, - bias=False, - ), - ) - if self._apply_shortcut - else None - ) - - @property - def _apply_shortcut(self) -> bool: - """If shortcut should be applied or not.""" - return self.stride != 1 or self.in_planes != self.out_planes - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - residual = x - if self._apply_shortcut: - residual = self.shortcut(x) - x = self.blocks(x) - x += residual - return x - - -class WideResidualNetwork(nn.Module): - """WideResNet for character predictions. - - Can be used for classification or encoding of images to a latent vector. - - """ - - def __init__( - self, - in_channels: int = 1, - in_planes: int = 16, - num_classes: int = 80, - depth: int = 16, - width_factor: int = 10, - dropout_rate: float = 0.0, - num_layers: int = 3, - block: Type[nn.Module] = WideBlock, - num_stages: Optional[List[int]] = None, - activation: str = "relu", - use_decoder: bool = True, - ) -> None: - """The initialization of the WideResNet. - - Args: - in_channels (int): Number of input channels. Defaults to 1. - in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. - num_classes (int): Number of classes. Defaults to 80. - depth (int): Set the number of blocks to use. Defaults to 16. - width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. - dropout_rate (float): The dropout rate. Defaults to 0.0. - num_layers (int): Number of layers of blocks. Defaults to 3. - block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. - num_stages (List[int]): If given, will use these channel values. Defaults to None. - activation (str): Name of the activation to use. Defaults to "relu". - use_decoder (bool): If True, the network output character predictions, if False, the network outputs a - latent vector. Defaults to True. - - Raises: - RuntimeError: If the depth is not of the size `6n+4`. - - """ - - super().__init__() - if (depth - 4) % 6 != 0: - raise RuntimeError("Wide-resnet depth should be 6n+4") - self.in_channels = in_channels - self.in_planes = in_planes - self.num_classes = num_classes - self.num_blocks = (depth - 4) // 6 - self.width_factor = width_factor - self.num_layers = num_layers - self.block = block - self.dropout_rate = dropout_rate - self.activation = activation_function(activation) - - if num_stages is None: - self.num_stages = [self.in_planes] + [ - self.in_planes * 2 ** n * self.width_factor - for n in range(self.num_layers) - ] - else: - self.num_stages = [self.in_planes] + num_stages - - self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) - self.strides = [1] + [2] * (self.num_layers - 1) - - self.encoder = nn.Sequential( - conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), - *[ - self._configure_wide_layer( - in_planes=in_planes, - out_planes=out_planes, - stride=stride, - activation=activation, - ) - for (in_planes, out_planes), stride in zip( - self.num_stages, self.strides - ) - ], - ) - - self.decoder = ( - nn.Sequential( - nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), - self.activation, - Reduce("b c h w -> b c", "mean"), - nn.Linear( - in_features=self.num_stages[-1][-1], out_features=self.num_classes - ), - ) - if use_decoder - else None - ) - - # self.apply(conv_init) - - def _configure_wide_layer( - self, in_planes: int, out_planes: int, stride: int, activation: str - ) -> List: - strides = [stride] + [1] * (self.num_blocks - 1) - planes = [out_planes] * len(strides) - planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) - return nn.Sequential( - *[ - self.block( - in_planes=in_planes, - out_planes=out_planes, - dropout_rate=self.dropout_rate, - stride=stride, - activation=activation, - ) - for (in_planes, out_planes), stride in zip(planes, strides) - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Feedforward pass.""" - if len(x.shape) < 4: - x = x[(None,) * int(4 - len(x.shape))] - x = self.encoder(x) - if self.decoder is not None: - x = self.decoder(x) - return x -- cgit v1.2.3-70-g09d2