diff options
Diffstat (limited to 'text_recognizer/networks/transducer/tds_conv.py')
-rw-r--r-- | text_recognizer/networks/transducer/tds_conv.py | 208 |
1 files changed, 0 insertions, 208 deletions
diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py deleted file mode 100644 index 5fb8ba9..0000000 --- a/text_recognizer/networks/transducer/tds_conv.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Time-Depth Separable Convolutions. - -References: - https://arxiv.org/abs/1904.02619 - https://arxiv.org/pdf/2010.01003.pdf - -Code stolen from: - https://github.com/facebookresearch/gtn_applications - - -""" -from typing import List, Tuple - -from einops import rearrange -import gtn -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class TDSBlock2d(nn.Module): - """Internal block of a 2D TDSC network.""" - - def __init__( - self, - in_channels: int, - img_depth: int, - kernel_size: Tuple[int], - dropout_rate: float, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.img_depth = img_depth - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - self.fc_dim = in_channels * img_depth - - # Network placeholders. - self.conv = None - self.mlp = None - self.instance_norm = None - - self._build_block() - - def _build_block(self) -> None: - # Convolutional block. - self.conv = nn.Sequential( - nn.Conv3d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), - padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - ) - - # MLP block. - self.mlp = nn.Sequential( - nn.Linear(self.fc_dim, self.fc_dim), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.Linear(self.fc_dim, self.fc_dim), - nn.Dropout(self.dropout_rate), - ) - - # Instance norm. - self.instance_norm = nn.ModuleList( - [ - nn.InstanceNorm2d(self.fc_dim, affine=True), - nn.InstanceNorm2d(self.fc_dim, affine=True), - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, CD, H, W)` - - Returns: - Tensor: Output tensor. - - """ - B, CD, H, W = x.shape - C, D = self.in_channels, self.img_depth - residual = x - x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) - x = self.conv(x) - x = rearrange(x, "b c d h w -> b (c d) h w") - x += residual - - x = self.instance_norm[0](x) - - x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x - x + self.instance_norm[1](x) - - # Output shape: [B, CD, H, W] - return x - - -class TDS2d(nn.Module): - """TDS Netowrk. - - Structure is the following: - Downsample layer -> TDS2d group -> ... -> Linear output layer - - - """ - - def __init__( - self, - input_dim: int, - output_dim: int, - depth: int, - tds_groups: Tuple[int], - kernel_size: Tuple[int], - dropout_rate: float, - in_channels: int = 1, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.input_dim = input_dim - self.output_dim = output_dim - self.depth = depth - self.tds_groups = tds_groups - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - - self.tds = None - self.fc = None - - self._build_network() - - def _build_network(self) -> None: - in_channels = self.in_channels - modules = [] - stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) - if self.input_dim % stride_h: - raise RuntimeError( - f"Image height not divisible by total stride {stride_h}." - ) - - for tds_group in self.tds_groups: - # Add downsample layer. - out_channels = self.depth * tds_group["channels"] - modules.extend( - [ - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=self.kernel_size, - padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), - stride=tds_group["stride"], - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.InstanceNorm2d(out_channels, affine=True), - ] - ) - - for _ in range(tds_group["num_blocks"]): - modules.append( - TDSBlock2d( - tds_group["channels"], - self.depth, - self.kernel_size, - self.dropout_rate, - ) - ) - - in_channels = out_channels - - self.tds = nn.Sequential(*modules) - self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, H, W)` - - Returns: - Tensor: Output tensor. - - """ - if len(x.shape) == 4: - x = x.squeeze(1) # Squeeze the channel dim away. - - B, H, W = x.shape - x = rearrange( - x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels - ) - x = self.tds(x) - - # x shape: [B, C, H, W] - x = rearrange(x, "b c h w -> b w (c h)") - - return self.fc(x) |