summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transducer/tds_conv.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transducer/tds_conv.py')
-rw-r--r--text_recognizer/networks/transducer/tds_conv.py208
1 files changed, 0 insertions, 208 deletions
diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py
deleted file mode 100644
index 5fb8ba9..0000000
--- a/text_recognizer/networks/transducer/tds_conv.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""Time-Depth Separable Convolutions.
-
-References:
- https://arxiv.org/abs/1904.02619
- https://arxiv.org/pdf/2010.01003.pdf
-
-Code stolen from:
- https://github.com/facebookresearch/gtn_applications
-
-
-"""
-from typing import List, Tuple
-
-from einops import rearrange
-import gtn
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class TDSBlock2d(nn.Module):
- """Internal block of a 2D TDSC network."""
-
- def __init__(
- self,
- in_channels: int,
- img_depth: int,
- kernel_size: Tuple[int],
- dropout_rate: float,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.img_depth = img_depth
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
- self.fc_dim = in_channels * img_depth
-
- # Network placeholders.
- self.conv = None
- self.mlp = None
- self.instance_norm = None
-
- self._build_block()
-
- def _build_block(self) -> None:
- # Convolutional block.
- self.conv = nn.Sequential(
- nn.Conv3d(
- in_channels=self.in_channels,
- out_channels=self.in_channels,
- kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
- padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- )
-
- # MLP block.
- self.mlp = nn.Sequential(
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.Dropout(self.dropout_rate),
- )
-
- # Instance norm.
- self.instance_norm = nn.ModuleList(
- [
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- ]
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, CD, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- B, CD, H, W = x.shape
- C, D = self.in_channels, self.img_depth
- residual = x
- x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
- x = self.conv(x)
- x = rearrange(x, "b c d h w -> b (c d) h w")
- x += residual
-
- x = self.instance_norm[0](x)
-
- x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
- x + self.instance_norm[1](x)
-
- # Output shape: [B, CD, H, W]
- return x
-
-
-class TDS2d(nn.Module):
- """TDS Netowrk.
-
- Structure is the following:
- Downsample layer -> TDS2d group -> ... -> Linear output layer
-
-
- """
-
- def __init__(
- self,
- input_dim: int,
- output_dim: int,
- depth: int,
- tds_groups: Tuple[int],
- kernel_size: Tuple[int],
- dropout_rate: float,
- in_channels: int = 1,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.depth = depth
- self.tds_groups = tds_groups
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
-
- self.tds = None
- self.fc = None
-
- self._build_network()
-
- def _build_network(self) -> None:
- in_channels = self.in_channels
- modules = []
- stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
- if self.input_dim % stride_h:
- raise RuntimeError(
- f"Image height not divisible by total stride {stride_h}."
- )
-
- for tds_group in self.tds_groups:
- # Add downsample layer.
- out_channels = self.depth * tds_group["channels"]
- modules.extend(
- [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=self.kernel_size,
- padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- stride=tds_group["stride"],
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.InstanceNorm2d(out_channels, affine=True),
- ]
- )
-
- for _ in range(tds_group["num_blocks"]):
- modules.append(
- TDSBlock2d(
- tds_group["channels"],
- self.depth,
- self.kernel_size,
- self.dropout_rate,
- )
- )
-
- in_channels = out_channels
-
- self.tds = nn.Sequential(*modules)
- self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- if len(x.shape) == 4:
- x = x.squeeze(1) # Squeeze the channel dim away.
-
- B, H, W = x.shape
- x = rearrange(
- x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
- )
- x = self.tds(x)
-
- # x shape: [B, C, H, W]
- x = rearrange(x, "b c h w -> b w (c h)")
-
- return self.fc(x)