summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-01 23:53:50 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-01 23:53:50 +0200
commit58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (patch)
treec89c1b1a4cc1a499900f2700ab09e8535e2cfe99 /text_recognizer/networks/encoders
parent7ae1f8f9654dcea0a9a22310ac0665a5d3202f0f (diff)
Working on new attention module
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r--text_recognizer/networks/encoders/__init__.py2
-rw-r--r--text_recognizer/networks/encoders/efficientnet.py145
-rw-r--r--text_recognizer/networks/encoders/residual_network.py310
-rw-r--r--text_recognizer/networks/encoders/wide_resnet.py221
4 files changed, 678 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py
new file mode 100644
index 0000000..25aed0e
--- /dev/null
+++ b/text_recognizer/networks/encoders/__init__.py
@@ -0,0 +1,2 @@
+"""Vision backbones."""
+from .efficientnet import EfficientNet
diff --git a/text_recognizer/networks/encoders/efficientnet.py b/text_recognizer/networks/encoders/efficientnet.py
new file mode 100644
index 0000000..61dea77
--- /dev/null
+++ b/text_recognizer/networks/encoders/efficientnet.py
@@ -0,0 +1,145 @@
+"""Efficient net b0 implementation."""
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class ConvNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ padding: int,
+ groups: int = 1,
+ ) -> None:
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_features=out_channels),
+ nn.SiLU(inplace=True),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.block(x)
+
+
+class SqueezeExcite(nn.Module):
+ def __init__(self, in_channels: int, reduce_dim: int) -> None:
+ super().__init__()
+ self.se = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1]
+ nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1),
+ nn.SiLU(),
+ nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x * self.se(x)
+
+
+class InvertedResidulaBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ padding: int,
+ expand_ratio: float,
+ reduction: int = 4,
+ survival_prob: float = 0.8,
+ ) -> None:
+ super().__init__()
+ self.survival_prob = survival_prob
+ self.use_residual = in_channels == out_channels and stride == 1
+ hidden_dim = in_channels * expand_ratio
+ self.expand = in_channels != hidden_dim
+ reduce_dim = in_channels // reduction
+
+ if self.expand:
+ self.expand_conv = ConvNorm(
+ in_channels, hidden_dim, kernel_size=3, stride=1, padding=1
+ )
+
+ self.conv = nn.Sequential(
+ ConvNorm(
+ hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim
+ ),
+ SqueezeExcite(hidden_dim, reduce_dim),
+ nn.Conv2d(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_features=out_channels),
+ )
+
+ def stochastic_depth(self, x: Tensor) -> Tensor:
+ if not self.training:
+ return x
+
+ binary_tensor = (
+ torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
+ )
+ return torch.div(x, self.survival_prob) * binary_tensor
+
+ def forward(self, x: Tensor) -> Tensor:
+ out = self.expand_conv(x) if self.expand else x
+ if self.use_residual:
+ return self.stochastic_depth(self.conv(out)) + x
+ return self.conv(out)
+
+
+class EfficientNet(nn.Module):
+ """Efficient net b0 backbone."""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.base_model = [
+ # expand_ratio, channels, repeats, stride, kernel_size
+ [1, 16, 1, 1, 3],
+ [6, 24, 2, 2, 3],
+ [6, 40, 2, 2, 5],
+ [6, 80, 3, 2, 3],
+ [6, 112, 3, 1, 5],
+ [6, 192, 4, 2, 5],
+ [6, 320, 1, 1, 3],
+ ]
+
+ self.backbone = self._build_b0()
+
+ def _build_b0(self) -> nn.Sequential:
+ in_channels = 32
+ layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)]
+
+ for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model:
+ for i in range(repeats):
+ layers.append(
+ InvertedResidulaBlock(
+ in_channels,
+ out_channels,
+ expand_ratio=expand_ratio,
+ stride=stride if i == 0 else 1,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ in_channels = out_channels
+ layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.backbone(x)
diff --git a/text_recognizer/networks/encoders/residual_network.py b/text_recognizer/networks/encoders/residual_network.py
new file mode 100644
index 0000000..c33f419
--- /dev/null
+++ b/text_recognizer/networks/encoders/residual_network.py
@@ -0,0 +1,310 @@
+"""Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class Conv2dAuto(nn.Conv2d):
+ """Convolution with auto padding based on kernel size."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
+
+
+def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
+ """3x3 convolution with batch norm."""
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ return nn.Sequential(
+ conv3x3(in_channels, out_channels, *args, **kwargs),
+ nn.BatchNorm2d(out_channels),
+ )
+
+
+class IdentityBlock(nn.Module):
+ """Residual with identity block."""
+
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu"
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.blocks = nn.Identity()
+ self.activation_fn = activation_function(activation)
+ self.shortcut = nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self.apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ x = self.activation_fn(x)
+ return x
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.out_channels
+
+
+class ResidualBlock(IdentityBlock):
+ """Residual with nonlinear shortcut."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion: int = 1,
+ downsampling: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ """Short summary.
+
+ Args:
+ in_channels (int): Number of in channels.
+ out_channels (int): umber of out channels.
+ expansion (int): Expansion factor of the out channels. Defaults to 1.
+ downsampling (int): Downsampling factor used in stride. Defaults to 1.
+ *args (type): Extra arguments.
+ **kwargs (type): Extra key value arguments.
+
+ """
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.expansion = expansion
+ self.downsampling = downsampling
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ stride=self.downsampling,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.expanded_channels),
+ )
+ if self.apply_shortcut
+ else None
+ )
+
+ @property
+ def expanded_channels(self) -> int:
+ """Computes the expanded output channels."""
+ return self.out_channels * self.expansion
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.expanded_channels
+
+
+class BasicBlock(ResidualBlock):
+ """Basic ResNet block."""
+
+ expansion = 1
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ bias=False,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ bias=False,
+ ),
+ )
+
+
+class BottleNeckBlock(ResidualBlock):
+ """Bottleneck block to increase depth while minimizing parameter size."""
+
+ expansion = 4
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ ),
+ )
+
+
+class ResidualLayer(nn.Module):
+ """ResNet layer."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ block: BasicBlock = BasicBlock,
+ num_blocks: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ downsampling = 2 if in_channels != out_channels else 1
+ self.blocks = nn.Sequential(
+ block(
+ in_channels, out_channels, *args, **kwargs, downsampling=downsampling
+ ),
+ *[
+ block(
+ out_channels * block.expansion,
+ out_channels,
+ downsampling=1,
+ *args,
+ **kwargs
+ )
+ for _ in range(num_blocks - 1)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.blocks(x)
+ return x
+
+
+class ResidualNetworkEncoder(nn.Module):
+ """Encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ block_sizes: Union[int, List[int]] = (32, 64),
+ depths: Union[int, List[int]] = (2, 2),
+ activation: str = "relu",
+ block: Type[nn.Module] = BasicBlock,
+ levels: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ self.block_sizes = (
+ block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
+ )
+ self.depths = depths if isinstance(depths, list) else [depths] * levels
+ self.activation = activation
+ self.gate = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.block_sizes[0],
+ kernel_size=7,
+ stride=2,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.block_sizes[0]),
+ activation_function(self.activation),
+ # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+ )
+
+ self.blocks = self._configure_blocks(block)
+
+ def _configure_blocks(
+ self, block: Type[nn.Module], *args, **kwargs
+ ) -> nn.Sequential:
+ channels = [self.block_sizes[0]] + list(
+ zip(self.block_sizes, self.block_sizes[1:])
+ )
+ blocks = [
+ ResidualLayer(
+ in_channels=channels[0],
+ out_channels=channels[0],
+ num_blocks=self.depths[0],
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ ]
+ blocks += [
+ ResidualLayer(
+ in_channels=in_channels * block.expansion,
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ for (in_channels, out_channels), num_blocks in zip(
+ channels[1:], self.depths[1:]
+ )
+ ]
+
+ return nn.Sequential(*blocks)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.gate(x)
+ x = self.blocks(x)
+ return x
+
+
+class ResidualNetworkDecoder(nn.Module):
+ """Classification head."""
+
+ def __init__(self, in_features: int, num_classes: int = 80) -> None:
+ super().__init__()
+ self.decoder = nn.Sequential(
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(in_features=in_features, out_features=num_classes),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ return self.decoder(x)
+
+
+class ResidualNetwork(nn.Module):
+ """Full residual network."""
+
+ def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
+ super().__init__()
+ self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
+ self.decoder = ResidualNetworkDecoder(
+ in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
+ num_classes=num_classes,
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.encoder(x)
+ x = self.decoder(x)
+ return x
diff --git a/text_recognizer/networks/encoders/wide_resnet.py b/text_recognizer/networks/encoders/wide_resnet.py
new file mode 100644
index 0000000..b767778
--- /dev/null
+++ b/text_recognizer/networks/encoders/wide_resnet.py
@@ -0,0 +1,221 @@
+"""Wide Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Reduce
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """Helper function for a 3x3 2d convolution."""
+ return nn.Conv2d(
+ in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ )
+
+
+def conv_init(module: Type[nn.Module]) -> None:
+ """Initializes the weights for convolution and batchnorms."""
+ classname = module.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
+ nn.init.constant_(module.bias, 0)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
+
+
+class WideBlock(nn.Module):
+ """Block used in WideResNet."""
+
+ def __init__(
+ self,
+ in_planes: int,
+ out_planes: int,
+ dropout_rate: float,
+ stride: int = 1,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.dropout_rate = dropout_rate
+ self.stride = stride
+ self.activation = activation_function(activation)
+
+ # Build blocks.
+ self.blocks = nn.Sequential(
+ nn.BatchNorm2d(self.in_planes),
+ self.activation,
+ conv3x3(in_planes=self.in_planes, out_planes=self.out_planes),
+ nn.Dropout(p=self.dropout_rate),
+ nn.BatchNorm2d(self.out_planes),
+ self.activation,
+ conv3x3(
+ in_planes=self.out_planes,
+ out_planes=self.out_planes,
+ stride=self.stride,
+ ),
+ )
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_planes,
+ out_channels=self.out_planes,
+ kernel_size=1,
+ stride=self.stride,
+ bias=False,
+ ),
+ )
+ if self._apply_shortcut
+ else None
+ )
+
+ @property
+ def _apply_shortcut(self) -> bool:
+ """If shortcut should be applied or not."""
+ return self.stride != 1 or self.in_planes != self.out_planes
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self._apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ return x
+
+
+class WideResidualNetwork(nn.Module):
+ """WideResNet for character predictions.
+
+ Can be used for classification or encoding of images to a latent vector.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ in_planes: int = 16,
+ num_classes: int = 80,
+ depth: int = 16,
+ width_factor: int = 10,
+ dropout_rate: float = 0.0,
+ num_layers: int = 3,
+ block: Type[nn.Module] = WideBlock,
+ num_stages: Optional[List[int]] = None,
+ activation: str = "relu",
+ use_decoder: bool = True,
+ ) -> None:
+ """The initialization of the WideResNet.
+
+ Args:
+ in_channels (int): Number of input channels. Defaults to 1.
+ in_planes (int): Number of channels to use in the first output kernel. Defaults to 16.
+ num_classes (int): Number of classes. Defaults to 80.
+ depth (int): Set the number of blocks to use. Defaults to 16.
+ width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10.
+ dropout_rate (float): The dropout rate. Defaults to 0.0.
+ num_layers (int): Number of layers of blocks. Defaults to 3.
+ block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ num_stages (List[int]): If given, will use these channel values. Defaults to None.
+ activation (str): Name of the activation to use. Defaults to "relu".
+ use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
+ latent vector. Defaults to True.
+
+ Raises:
+ RuntimeError: If the depth is not of the size `6n+4`.
+
+ """
+
+ super().__init__()
+ if (depth - 4) % 6 != 0:
+ raise RuntimeError("Wide-resnet depth should be 6n+4")
+ self.in_channels = in_channels
+ self.in_planes = in_planes
+ self.num_classes = num_classes
+ self.num_blocks = (depth - 4) // 6
+ self.width_factor = width_factor
+ self.num_layers = num_layers
+ self.block = block
+ self.dropout_rate = dropout_rate
+ self.activation = activation_function(activation)
+
+ if num_stages is None:
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor
+ for n in range(self.num_layers)
+ ]
+ else:
+ self.num_stages = [self.in_planes] + num_stages
+
+ self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
+ self.strides = [1] + [2] * (self.num_layers - 1)
+
+ self.encoder = nn.Sequential(
+ conv3x3(in_planes=self.in_channels, out_planes=self.in_planes),
+ *[
+ self._configure_wide_layer(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(
+ self.num_stages, self.strides
+ )
+ ],
+ )
+
+ self.decoder = (
+ nn.Sequential(
+ nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8),
+ self.activation,
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(
+ in_features=self.num_stages[-1][-1], out_features=self.num_classes
+ ),
+ )
+ if use_decoder
+ else None
+ )
+
+ # self.apply(conv_init)
+
+ def _configure_wide_layer(
+ self, in_planes: int, out_planes: int, stride: int, activation: str
+ ) -> List:
+ strides = [stride] + [1] * (self.num_blocks - 1)
+ planes = [out_planes] * len(strides)
+ planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:]))
+ return nn.Sequential(
+ *[
+ self.block(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ dropout_rate=self.dropout_rate,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(planes, strides)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass."""
+ if len(x.shape) < 4:
+ x = x[(None,) * int(4 - len(x.shape))]
+ x = self.encoder(x)
+ if self.decoder is not None:
+ x = self.decoder(x)
+ return x