Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--  src/text_recognizer/networks/__init__.py                |   8
-rw-r--r--  src/text_recognizer/networks/cnn.py                     | 101
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py         |  15
-rw-r--r--  src/text_recognizer/networks/metrics.py                 |  33
-rw-r--r--  src/text_recognizer/networks/transducer/__init__.py     |   2
-rw-r--r--  src/text_recognizer/networks/transducer/tds_conv.py     | 205
-rw-r--r--  src/text_recognizer/networks/util.py                    |   9
-rw-r--r--  src/text_recognizer/networks/vq_transformer.py          | 150
-rw-r--r--  src/text_recognizer/networks/vqvae/__init__.py          |   4
-rw-r--r--  src/text_recognizer/networks/vqvae/decoder.py           | 133
-rw-r--r--  src/text_recognizer/networks/vqvae/encoder.py           | 125
-rw-r--r--  src/text_recognizer/networks/vqvae/vector_quantizer.py  |   2
-rw-r--r--  src/text_recognizer/networks/vqvae/vqvae.py             |  74
13 files changed, 832 insertions(+), 29 deletions(-)
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 2b624bb..bac5d28 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,4 +1,5 @@
 """Network modules."""
+from .cnn import CNN
 from .cnn_transformer import CNNTransformer
 from .crnn import ConvolutionalRecurrentNetwork
 from .ctc import greedy_decoder
@@ -7,15 +8,19 @@ from .lenet import LeNet
 from .metrics import accuracy, cer, wer
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .transducer import TDS2d
 from .transformer import Transformer
 from .unet import UNet
 from .util import sliding_window
 from .vit import ViT
+from .vq_transformer import VQTransformer
+from .vqvae import VQVAE
 from .wide_resnet import WideResidualNetwork
 
 __all__ = [
     "accuracy",
     "cer",
+    "CNN",
     "CNNTransformer",
     "ConvolutionalRecurrentNetwork",
     "DenseNet",
@@ -27,8 +32,11 @@ __all__ = [
     "ResidualNetworkEncoder",
     "sliding_window",
     "UNet",
+    "TDS2d",
     "Transformer",
     "ViT",
+    "VQTransformer",
+    "VQVAE",
     "wer",
     "WideResidualNetwork",
 ]
diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..1807bb9
--- /dev/null
+++ b/src/text_recognizer/networks/cnn.py
@@ -0,0 +1,101 @@
+"""Implementation of a simple backbone CNN network."""
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class CNN(nn.Module):
+    """A simple CNN backbone for character prediction."""
+
+    def __init__(
+        self,
+        channels: Tuple[int, ...] = (1, 32, 64, 128),
+        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
+        strides: Tuple[int, ...] = (2, 2, 2),
+        max_pool_kernel: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Optional[str] = "relu",
+    ) -> None:
+        """Initialization of the CNN backbone.
+
+        Args:
+            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64, 128).
+            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (4, 4, 4).
+            strides (Tuple[int, ...]): Stride lengths of the convolutional filters. Defaults to (2, 2, 2).
+            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
+            dropout_rate (float): The dropout rate. Defaults to 0.2.
+            activation (Optional[str]): The name of the non-linear activation function. Defaults to relu.
+
+        Raises:
+            RuntimeError: If the numbers of hyperparameters do not match in length.
+
+        """
+        super().__init__()
+
+        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
+            raise RuntimeError("The number of the hyperparameters does not match.")
+
+        self.cnn = self._build_network(
+            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
+        )
+
+    def _build_network(
+        self,
+        channels: Tuple[int, ...],
+        kernel_sizes: Tuple[int, ...],
+        strides: Tuple[int, ...],
+        max_pool_kernel: int,
+        dropout_rate: float,
+        activation: str,
+    ) -> nn.Sequential:
+        # Load the activation function.
+        activation_fn = activation_function(activation)
+
+        channels = list(channels)
+        in_channels = channels.pop(0)
+        configuration = zip(channels, kernel_sizes, strides)
+
+        modules = nn.ModuleList([])
+
+        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+            # Add max pool to reduce output size.
+            if i == len(channels) // 2:
+                modules.append(nn.MaxPool2d(max_pool_kernel))
+            if i == 0:
+                modules.append(
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size, stride=stride, padding=1
+                    )
+                )
+            else:
+                modules.append(
+                    nn.Sequential(
+                        activation_fn,
+                        nn.BatchNorm2d(in_channels),
+                        nn.Conv2d(
+                            in_channels,
+                            out_channels,
+                            kernel_size,
+                            stride=stride,
+                            padding=1,
+                        ),
+                    )
+                )
+
+            if dropout_rate:
+                modules.append(nn.Dropout2d(p=dropout_rate))
+
+            in_channels = out_channels
+
+        return nn.Sequential(*modules)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        # If the batch dimension is missing, it needs to be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.cnn(x)
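A quick smoke test of the new backbone (a minimal sketch; the 28x28 input size is an assumption, since forward pads any missing batch and channel dimensions before the convolutions):

    import torch

    from text_recognizer.networks import CNN

    cnn = CNN()  # Defaults: channels (1, 32, 64, 128), three stride-2 convolutions.
    x = torch.randn(28, 28)  # Missing batch/channel dims are added in forward.
    features = cnn(x)
    print(features.shape)  # torch.Size([1, 128, 1, 1]) for a 28x28 input.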
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 43e5403..7133c26 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -29,14 +29,22 @@ class CNNTransformer(nn.Module):
         backbone: str,
         backbone_args: Optional[Dict] = None,
         activation: str = "gelu",
+        pool_kernel: Optional[Tuple[int, int]] = None,
     ) -> None:
         super().__init__()
         self.trg_pad_index = trg_pad_index
         self.vocab_size = vocab_size
         self.backbone = configure_backbone(backbone, backbone_args)
+
+        if pool_kernel is not None:
+            self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+        else:
+            self.max_pool = None
+
         self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
         self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.pos_dropout = nn.Dropout(p=dropout_rate)
         self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
 
         nn.init.normal_(self.character_embedding.weight, std=0.02)
@@ -98,18 +106,23 @@ class CNNTransformer(nn.Module):
         # If the batch dimension is missing, it needs to be added.
         if len(src.shape) < 4:
             src = src[(None,) * (4 - len(src.shape))]
+
         src = self.backbone(src)
 
+        if self.max_pool is not None:
+            src = self.max_pool(src)
+
         if self.adaptive_pool is not None:
             src = rearrange(src, "b c h w -> b w c h")
             src = self.adaptive_pool(src)
             src = src.squeeze(3)
         else:
-            src = rearrange(src, "b c h w -> b (w h) c")
+            src = rearrange(src, "b c h w -> b (h w) c")
 
         b, t, _ = src.shape
 
         src += self.src_position_embedding[:, :t]
+        src = self.pos_dropout(src)
 
         return src
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index ffad792..2605731 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -1,4 +1,7 @@
 """Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
 import Levenshtein as Lev
 import torch
 from torch import Tensor
@@ -32,22 +35,33 @@ def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
     return acc
 
 
-def cer(outputs: Tensor, targets: Tensor) -> float:
+def cer(
+    outputs: Tensor,
+    targets: Tensor,
+    batch_size: Optional[int] = None,
+    blank_label: Optional[int] = 79,
+) -> float:
     """Computes the character error rate.
 
     Args:
         outputs (Tensor): The output from the network.
         targets (Tensor): Ground truth labels.
+        batch_size (Optional[int]): Batch size if target and output have been flattened.
+        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
 
     Returns:
         float: The cer for the batch.
 
     """
+    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
     target_lengths = torch.full(
         size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
     )
     decoded_predictions, decoded_targets = greedy_decoder(
-        outputs, targets, target_lengths
+        outputs, targets, target_lengths, blank_label=blank_label,
     )
 
     lev_dist = 0
@@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float:
     return lev_dist / len(decoded_predictions)
 
 
-def wer(outputs: Tensor, targets: Tensor) -> float:
+def wer(
+    outputs: Tensor,
+    targets: Tensor,
+    batch_size: Optional[int] = None,
+    blank_label: Optional[int] = 79,
+) -> float:
     """Computes the word error rate.
 
     Args:
         outputs (Tensor): The output from the network.
         targets (Tensor): Ground truth labels.
+        batch_size (Optional[int]): Batch size if target and output have been flattened.
+        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
 
     Returns:
         float: The wer for the batch.
 
     """
+    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
     target_lengths = torch.full(
         size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
     decoded_predictions, decoded_targets = greedy_decoder(
-        outputs, targets, target_lengths
+        outputs, targets, target_lengths, blank_label=blank_label,
     )
 
     lev_dist = 0
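The new batch_size argument lets the metrics un-flatten CTC-style outputs before greedy decoding. A hedged sketch of the call (the shapes are assumptions; greedy_decoder itself is defined elsewhere in the package):

    import torch

    from text_recognizer.networks import cer

    b, t, v = 4, 16, 80  # Hypothetical batch size, sequence length, vocabulary size.
    outputs = torch.randn(b * t, v)          # Flattened network output: (b*t, v).
    targets = torch.randint(0, v, (b * t,))  # Flattened targets: (b*t,).

    # Internally rearranged to (t, b, v) and (b, t) before decoding.
    error_rate = cer(outputs, targets, batch_size=b, blank_label=79)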
""" + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + target_lengths = torch.full( size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, ) decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths + outputs, targets, target_lengths, blank_label=blank_label, ) lev_dist = 0 @@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float: return lev_dist / len(decoded_predictions) -def wer(outputs: Tensor, targets: Tensor) -> float: +def wer( + outputs: Tensor, + targets: Tensor, + batch_size: Optional[int] = None, + blank_label: Optional[int] = int, +) -> float: """Computes the Word error rate. Args: outputs (Tensor): The output from the network. targets (Tensor): Ground truth labels. + batch_size (optional[int]): Batch size if target and output has been flattend. + blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. Returns: float: The wer for the batch. """ + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + target_lengths = torch.full( size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, ) decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths + outputs, targets, target_lengths, blank_label=blank_label, ) lev_dist = 0 diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py new file mode 100644 index 0000000..fdd6662 --- /dev/null +++ b/src/text_recognizer/networks/transducer/__init__.py @@ -0,0 +1,2 @@ +"""Transducer modules.""" +from .tds_conv import TDS2d diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py new file mode 100644 index 0000000..018caf2 --- /dev/null +++ b/src/text_recognizer/networks/transducer/tds_conv.py @@ -0,0 +1,205 @@ +"""Time-Depth Separable Convolutions. + +References: + https://arxiv.org/abs/1904.02619 + https://arxiv.org/pdf/2010.01003.pdf + +Code stolen from: + https://github.com/facebookresearch/gtn_applications + + +""" +from typing import List, Tuple + +from einops import rearrange +import gtn +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class TDSBlock2d(nn.Module): + """Internal block of a 2D TDSC network.""" + + def __init__( + self, + in_channels: int, + img_depth: int, + kernel_size: Tuple[int], + dropout_rate: float, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.img_depth = img_depth + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + self.fc_dim = in_channels * img_depth + + # Network placeholders. + self.conv = None + self.mlp = None + self.instance_norm = None + + self._build_block() + + def _build_block(self) -> None: + # Convolutional block. + self.conv = nn.Sequential( + nn.Conv3d( + in_channels=self.in_channels, + out_channels=self.in_channels, + kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), + padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), + ), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + ) + + # MLP block. 
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index 711a952..131a6b4 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -65,13 +65,18 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
         network_args = state_dict["network_args"]
         weights = state_dict["model_state"]
 
+        freeze = False
+        if "freeze" in backbone_args and backbone_args["freeze"] is True:
+            backbone_args.pop("freeze")
+            freeze = True
+        network_args = backbone_args
+
+        # Initializes the network with trained weights.
         backbone = backbone_(**network_args)
         backbone.load_state_dict(weights)
-        if "freeze" in backbone_args and backbone_args["freeze"] is True:
+        if freeze:
             for params in backbone.parameters():
                 params.requires_grad = False
-
     else:
         backbone_ = getattr(network_module, backbone)
         backbone = backbone_(**backbone_args)
diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..c673d96
--- /dev/null
+++ b/src/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,150 @@
+"""A VQ-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class VQTransformer(nn.Module):
+    """VQ + Transformer for image to character sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        adaptive_pool_dim: Tuple,
+        expansion_dim: int,
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        backbone: str,
+        backbone_args: Optional[Dict] = None,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+
+        # Configure the vector quantized backbone.
+        self.backbone = configure_backbone(backbone, backbone_args)
+        self.conv = nn.Sequential(
+            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
+            nn.ReLU(inplace=True),
+        )
+
+        # Configure embeddings for the Transformer network.
+        self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+        self.adaptive_pool = (
+            nn.AdaptiveAvgPool2d(adaptive_pool_dim) if adaptive_pool_dim else None
+        )
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # TODO: Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
+        """Extracts image features with the vector quantized backbone.
+
+        The winning idea was to swap the channel and width dimensions and collapse
+        the height dimension, which made the transformer train dramatically better.
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: The src input to the transformer and the VQ loss.
+
+        """
+        # If the batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+        src, vq_loss = self.backbone.encode(src)
+        # src = self.backbone.decoder.res_block(src)
+        src = self.conv(src)
+
+        if self.adaptive_pool is not None:
+            src = rearrange(src, "b c h w -> b w c h")
+            src = self.adaptive_pool(src)
+            src = src.squeeze(3)
+        else:
+            src = rearrange(src, "b c h w -> b (w h) c")
+
+        b, t, _ = src.shape
+
+        src += self.src_position_embedding[:, :t]
+
+        return src, vq_loss
+
+    def target_embedding(self, trg: Tensor) -> Tensor:
+        """Encodes the target tensor with character embedding and position encoding.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: Encoded target tensor.
+
+        """
+        trg = self.character_embedding(trg.long())
+        trg = self.trg_position_encoding(trg)
+        return trg
+
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
+        """Takes image features from the backbone and decodes them with the transformer."""
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+        """Forward pass with the VQ-Transformer."""
+        image_features, vq_loss = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
+        return logits, vq_loss
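_create_trg_mask combines a padding mask with a causal lower-triangular mask. The same computation as a standalone sketch (pad index 0 is an assumption):

    import torch

    trg = torch.tensor([[5, 12, 3, 0, 0]])          # 0 is the assumed pad index.
    trg_pad_mask = (trg != 0)[:, None, None]        # (B, 1, 1, T)
    trg_len = trg.shape[1]
    trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).bool()  # (T, T) causal mask.
    trg_mask = trg_pad_mask & trg_sub_mask          # Broadcasts to (B, 1, T, T).
    print(trg_mask[0, 0])  # Row i attends to positions <= i that are not padding.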
+ + """ + trg = self.character_embedding(trg.long()) + trg = self.trg_position_encoding(trg) + return trg + + def decode_image_features( + self, image_features: Tensor, trg: Optional[Tensor] = None + ) -> Tensor: + """Takes images features from the backbone and decodes them with the transformer.""" + trg_mask = self._create_trg_mask(trg) + trg = self.target_embedding(trg) + out = self.transformer(image_features, trg, trg_mask=trg_mask) + + logits = self.head(out) + return logits + + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: + """Forward pass with CNN transfomer.""" + image_features, vq_loss = self.extract_image_features(x) + logits = self.decode_image_features(image_features, trg) + return logits, vq_loss diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py index e1f05fa..763953c 100644 --- a/src/text_recognizer/networks/vqvae/__init__.py +++ b/src/text_recognizer/networks/vqvae/__init__.py @@ -1 +1,5 @@ """VQ-VAE module.""" +from .decoder import Decoder +from .encoder import Encoder +from .vector_quantizer import VectorQuantizer +from .vqvae import VQVAE diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py new file mode 100644 index 0000000..8847aba --- /dev/null +++ b/src/text_recognizer/networks/vqvae/decoder.py @@ -0,0 +1,133 @@ +"""CNN decoder for the VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.encoder import _ResidualBlock + + +class Decoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + upsampling: Optional[List[List[int]]] = None, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.upsampling = upsampling + + self.res_block = nn.ModuleList([]) + self.upsampling_block = nn.ModuleList([]) + + self.embedding_dim = embedding_dim + activation = activation_function(activation) + + # Configure encoder. 
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
index 60c4c43..d3adac5 100644
--- a/src/text_recognizer/networks/vqvae/encoder.py
+++ b/src/text_recognizer/networks/vqvae/encoder.py
@@ -1,6 +1,5 @@
 """CNN encoder for the VQ-VAE."""
-
-from typing import List, Optional, Type
+from typing import List, Optional, Tuple, Type
 
 import torch
 from torch import nn
@@ -12,16 +11,12 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
 class _ResidualBlock(nn.Module):
     def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        activation: Type[nn.Module],
-        dropout: Optional[Type[nn.Module]],
+        self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
     ) -> None:
         super().__init__()
         self.block = [
             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
-            activation,
+            nn.ReLU(inplace=True),
             nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
         ]
@@ -42,23 +37,111 @@ class Encoder(nn.Module):
         self,
         in_channels: int,
         channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
         num_residual_layers: int,
         embedding_dim: int,
         num_embeddings: int,
         beta: float = 0.25,
-        activation: str = "elu",
+        activation: str = "leaky_relu",
         dropout_rate: float = 0.0,
     ) -> None:
         super().__init__()
-        pass
-        # if dropout_rate:
-        #     if activation == "selu":
-        #         dropout = nn.AlphaDropout(p=dropout_rate)
-        #     else:
-        #         dropout = nn.Dropout(p=dropout_rate)
-        # else:
-        #     dropout = None
-
-    def _build_encoder(self) -> nn.Sequential:
-        # TODO: Continue to implement encoder.
-        pass
+
+        if dropout_rate:
+            if activation == "selu":
+                dropout = nn.AlphaDropout(p=dropout_rate)
+            else:
+                dropout = nn.Dropout(p=dropout_rate)
+        else:
+            dropout = None
+
+        self.embedding_dim = embedding_dim
+        self.num_embeddings = num_embeddings
+        self.beta = beta
+        activation = activation_function(activation)
+
+        # Configure encoder.
+        self.encoder = self._build_encoder(
+            in_channels,
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            activation,
+            dropout,
+        )
+
+        # Configure the vector quantizer.
+        self.vector_quantizer = VectorQuantizer(
+            self.num_embeddings, self.embedding_dim, self.beta
+        )
+
+    def _build_compression_block(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.ModuleList:
+        modules = nn.ModuleList([])
+        configuration = zip(channels, kernel_sizes, strides)
+        for out_channels, kernel_size, stride in configuration:
+            modules.append(
+                nn.Sequential(
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size, stride=stride, padding=1
+                    ),
+                    activation,
+                )
+            )
+
+            if dropout is not None:
+                modules.append(dropout)
+
+            in_channels = out_channels
+
+        return modules
+
+    def _build_encoder(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.Sequential:
+        encoder = nn.ModuleList([])
+
+        # Compression module.
+        encoder.extend(
+            self._build_compression_block(
+                in_channels, channels, kernel_sizes, strides, activation, dropout
+            )
+        )
+
+        # Bottleneck module.
+        encoder.extend(
+            nn.ModuleList(
+                [
+                    _ResidualBlock(channels[-1], channels[-1], dropout)
+                    for _ in range(num_residual_layers)
+                ]
+            )
+        )
+
+        encoder.append(
+            nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1)
+        )
+
+        return nn.Sequential(*encoder)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes the input into a discrete latent representation."""
+        z_e = self.encoder(x)
+        z_q, vq_loss = self.vector_quantizer(z_e)
+        return z_q, vq_loss
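A hedged shape walkthrough for the finished encoder (all hyperparameters below are made up; each stride-2 convolution halves both spatial dimensions, and the vector quantizer preserves the latent shape):

    import torch

    from text_recognizer.networks.vqvae import Encoder

    encoder = Encoder(
        in_channels=1,
        channels=[32, 64],
        kernel_sizes=[4, 4],
        strides=[2, 2],
        num_residual_layers=2,
        embedding_dim=64,
        num_embeddings=512,
    )
    x = torch.randn(1, 1, 32, 128)
    z_q, vq_loss = encoder(x)
    print(z_q.shape)  # (1, 64, 8, 32): quantized latents at 1/4 resolution.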
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
index 25e5583..f92c7ee 100644
--- a/src/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -26,7 +26,7 @@ class VectorQuantizer(nn.Module):
         self.embedding = nn.Embedding(self.K, self.D)
 
         # Initialize the codebook.
-        self.embedding.weight.uniform_(-1 / self.K, 1 / self.K)
+        nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
 
     def discretization_bottleneck(self, latent: Tensor) -> Tensor:
         """Computes the code nearest to the latent representation.
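This one-line change is presumably about autograd: calling .uniform_() directly on nn.Embedding.weight is an in-place operation on a leaf tensor that requires grad, which PyTorch rejects, while nn.init.uniform_ performs the same initialization inside torch.no_grad():

    from torch import nn

    embedding = nn.Embedding(512, 64)  # K=512 codes of dimension D=64.

    # embedding.weight.uniform_(-1 / 512, 1 / 512)
    # -> RuntimeError: a leaf Variable that requires grad is being used
    #    in an in-place operation.

    nn.init.uniform_(embedding.weight, -1 / 512, 1 / 512)  # Safe: wrapped in no_grad().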
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+from typing import List, Optional, Tuple
+
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+    """Vector Quantized Variational AutoEncoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        embedding_dim: int,
+        num_embeddings: int,
+        upsampling: Optional[List[List[int]]] = None,
+        beta: float = 0.25,
+        activation: str = "leaky_relu",
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        # Configure encoder.
+        self.encoder = Encoder(
+            in_channels,
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            embedding_dim,
+            num_embeddings,
+            beta,
+            activation,
+            dropout_rate,
+        )
+
+        # Configure decoder. The layer configuration is mirrored in place.
+        channels.reverse()
+        kernel_sizes.reverse()
+        strides.reverse()
+        self.decoder = Decoder(
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            embedding_dim,
+            upsampling,
+            activation,
+            dropout_rate,
+        )
+
+    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes the input to a latent code."""
+        return self.encoder(x)
+
+    def decode(self, z_q: Tensor) -> Tensor:
+        """Reconstructs the input from latent codes."""
+        return self.decoder(z_q)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Compresses and decompresses the input."""
+        # If the batch dimension is missing, it needs to be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        z_q, vq_loss = self.encode(x)
+        x_reconstruction = self.decode(z_q)
+        return x_reconstruction, vq_loss
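A minimal round-trip sketch of the assembled VQ-VAE (the hyperparameters are assumptions; note that the constructor reverses the channels, kernel_sizes, and strides lists in place so the decoder mirrors the encoder):

    import torch

    from text_recognizer.networks import VQVAE

    vqvae = VQVAE(
        in_channels=1,
        channels=[32, 64],
        kernel_sizes=[4, 4],
        strides=[2, 2],
        num_residual_layers=2,
        embedding_dim=64,
        num_embeddings=512,
    )
    x = torch.randn(1, 1, 32, 128)
    x_reconstruction, vq_loss = vqvae(x)
    print(x_reconstruction.shape)  # (1, 1, 32, 128) when the decoder mirrors the encoder.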