summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/wide_resnet.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-01 23:53:50 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-01 23:53:50 +0200
commit58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (patch)
treec89c1b1a4cc1a499900f2700ab09e8535e2cfe99 /text_recognizer/networks/encoders/wide_resnet.py
parent7ae1f8f9654dcea0a9a22310ac0665a5d3202f0f (diff)
Working on new attention module
Diffstat (limited to 'text_recognizer/networks/encoders/wide_resnet.py')
-rw-r--r--text_recognizer/networks/encoders/wide_resnet.py221
1 files changed, 221 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/wide_resnet.py b/text_recognizer/networks/encoders/wide_resnet.py
new file mode 100644
index 0000000..b767778
--- /dev/null
+++ b/text_recognizer/networks/encoders/wide_resnet.py
@@ -0,0 +1,221 @@
+"""Wide Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Reduce
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """Helper function for a 3x3 2d convolution."""
+ return nn.Conv2d(
+ in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ )
+
+
+def conv_init(module: Type[nn.Module]) -> None:
+ """Initializes the weights for convolution and batchnorms."""
+ classname = module.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
+ nn.init.constant_(module.bias, 0)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
+
+
+class WideBlock(nn.Module):
+ """Block used in WideResNet."""
+
+ def __init__(
+ self,
+ in_planes: int,
+ out_planes: int,
+ dropout_rate: float,
+ stride: int = 1,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.dropout_rate = dropout_rate
+ self.stride = stride
+ self.activation = activation_function(activation)
+
+ # Build blocks.
+ self.blocks = nn.Sequential(
+ nn.BatchNorm2d(self.in_planes),
+ self.activation,
+ conv3x3(in_planes=self.in_planes, out_planes=self.out_planes),
+ nn.Dropout(p=self.dropout_rate),
+ nn.BatchNorm2d(self.out_planes),
+ self.activation,
+ conv3x3(
+ in_planes=self.out_planes,
+ out_planes=self.out_planes,
+ stride=self.stride,
+ ),
+ )
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_planes,
+ out_channels=self.out_planes,
+ kernel_size=1,
+ stride=self.stride,
+ bias=False,
+ ),
+ )
+ if self._apply_shortcut
+ else None
+ )
+
+ @property
+ def _apply_shortcut(self) -> bool:
+ """If shortcut should be applied or not."""
+ return self.stride != 1 or self.in_planes != self.out_planes
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self._apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ return x
+
+
+class WideResidualNetwork(nn.Module):
+ """WideResNet for character predictions.
+
+ Can be used for classification or encoding of images to a latent vector.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ in_planes: int = 16,
+ num_classes: int = 80,
+ depth: int = 16,
+ width_factor: int = 10,
+ dropout_rate: float = 0.0,
+ num_layers: int = 3,
+ block: Type[nn.Module] = WideBlock,
+ num_stages: Optional[List[int]] = None,
+ activation: str = "relu",
+ use_decoder: bool = True,
+ ) -> None:
+ """The initialization of the WideResNet.
+
+ Args:
+ in_channels (int): Number of input channels. Defaults to 1.
+ in_planes (int): Number of channels to use in the first output kernel. Defaults to 16.
+ num_classes (int): Number of classes. Defaults to 80.
+ depth (int): Set the number of blocks to use. Defaults to 16.
+ width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10.
+ dropout_rate (float): The dropout rate. Defaults to 0.0.
+ num_layers (int): Number of layers of blocks. Defaults to 3.
+ block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ num_stages (List[int]): If given, will use these channel values. Defaults to None.
+ activation (str): Name of the activation to use. Defaults to "relu".
+ use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
+ latent vector. Defaults to True.
+
+ Raises:
+ RuntimeError: If the depth is not of the size `6n+4`.
+
+ """
+
+ super().__init__()
+ if (depth - 4) % 6 != 0:
+ raise RuntimeError("Wide-resnet depth should be 6n+4")
+ self.in_channels = in_channels
+ self.in_planes = in_planes
+ self.num_classes = num_classes
+ self.num_blocks = (depth - 4) // 6
+ self.width_factor = width_factor
+ self.num_layers = num_layers
+ self.block = block
+ self.dropout_rate = dropout_rate
+ self.activation = activation_function(activation)
+
+ if num_stages is None:
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor
+ for n in range(self.num_layers)
+ ]
+ else:
+ self.num_stages = [self.in_planes] + num_stages
+
+ self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
+ self.strides = [1] + [2] * (self.num_layers - 1)
+
+ self.encoder = nn.Sequential(
+ conv3x3(in_planes=self.in_channels, out_planes=self.in_planes),
+ *[
+ self._configure_wide_layer(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(
+ self.num_stages, self.strides
+ )
+ ],
+ )
+
+ self.decoder = (
+ nn.Sequential(
+ nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8),
+ self.activation,
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(
+ in_features=self.num_stages[-1][-1], out_features=self.num_classes
+ ),
+ )
+ if use_decoder
+ else None
+ )
+
+ # self.apply(conv_init)
+
+ def _configure_wide_layer(
+ self, in_planes: int, out_planes: int, stride: int, activation: str
+ ) -> List:
+ strides = [stride] + [1] * (self.num_blocks - 1)
+ planes = [out_planes] * len(strides)
+ planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:]))
+ return nn.Sequential(
+ *[
+ self.block(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ dropout_rate=self.dropout_rate,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(planes, strides)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass."""
+ if len(x.shape) < 4:
+ x = x[(None,) * int(4 - len(x.shape))]
+ x = self.encoder(x)
+ if self.decoder is not None:
+ x = self.decoder(x)
+ return x