summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
commit4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/text_recognizer/networks
parentd691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/__init__.py8
-rw-r--r--src/text_recognizer/networks/cnn.py101
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py15
-rw-r--r--src/text_recognizer/networks/metrics.py33
-rw-r--r--src/text_recognizer/networks/transducer/__init__.py2
-rw-r--r--src/text_recognizer/networks/transducer/tds_conv.py205
-rw-r--r--src/text_recognizer/networks/util.py9
-rw-r--r--src/text_recognizer/networks/vq_transformer.py150
-rw-r--r--src/text_recognizer/networks/vqvae/__init__.py4
-rw-r--r--src/text_recognizer/networks/vqvae/decoder.py133
-rw-r--r--src/text_recognizer/networks/vqvae/encoder.py125
-rw-r--r--src/text_recognizer/networks/vqvae/vector_quantizer.py2
-rw-r--r--src/text_recognizer/networks/vqvae/vqvae.py74
13 files changed, 832 insertions, 29 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 2b624bb..bac5d28 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,4 +1,5 @@
"""Network modules."""
+from .cnn import CNN
from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
@@ -7,15 +8,19 @@ from .lenet import LeNet
from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .transducer import TDS2d
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
from .vit import ViT
+from .vq_transformer import VQTransformer
+from .vqvae import VQVAE
from .wide_resnet import WideResidualNetwork
__all__ = [
"accuracy",
"cer",
+ "CNN",
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
"DenseNet",
@@ -27,8 +32,11 @@ __all__ = [
"ResidualNetworkEncoder",
"sliding_window",
"UNet",
+ "TDS2d",
"Transformer",
"ViT",
+ "VQTransformer",
+ "VQVAE",
"wer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..1807bb9
--- /dev/null
+++ b/src/text_recognizer/networks/cnn.py
@@ -0,0 +1,101 @@
+"""Implementation of a simple backbone cnn network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class CNN(nn.Module):
+ """LeNet network for character prediction."""
+
+ def __init__(
+ self,
+ channels: Tuple[int, ...] = (1, 32, 64, 128),
+ kernel_sizes: Tuple[int, ...] = (4, 4, 4),
+ strides: Tuple[int, ...] = (2, 2, 2),
+ max_pool_kernel: int = 2,
+ dropout_rate: float = 0.2,
+ activation: Optional[str] = "relu",
+ ) -> None:
+ """Initialization of the LeNet network.
+
+ Args:
+ channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+ kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
+ strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
+ max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
+ dropout_rate (float): The dropout rate. Defaults to 0.2.
+ activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+ Raises:
+ RuntimeError: if the number of hyperparameters does not match in length.
+
+ """
+ super().__init__()
+
+ if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides):
+ raise RuntimeError("The number of the hyperparameters does not match.")
+
+ self.cnn = self._build_network(
+ channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
+ )
+
+ def _build_network(
+ self,
+ channels: Tuple[int, ...],
+ kernel_sizes: Tuple[int, ...],
+ strides: Tuple[int, ...],
+ max_pool_kernel: int,
+ dropout_rate: float,
+ activation: str,
+ ) -> nn.Sequential:
+ # Load activation function.
+ activation_fn = activation_function(activation)
+
+ channels = list(channels)
+ in_channels = channels.pop(0)
+ configuration = zip(channels, kernel_sizes, strides)
+
+ modules = nn.ModuleList([])
+
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ # Add max pool to reduce output size.
+ if i == len(channels) // 2:
+ modules.append(nn.MaxPool2d(max_pool_kernel))
+ if i == 0:
+ modules.append(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ )
+ )
+ else:
+ modules.append(
+ nn.Sequential(
+ activation_fn,
+ nn.BatchNorm2d(in_channels),
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ )
+ )
+
+ if dropout_rate:
+ modules.append(nn.Dropout2d(p=dropout_rate))
+
+ in_channels = out_channels
+
+ return nn.Sequential(*modules)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.cnn(x)
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 43e5403..7133c26 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -29,14 +29,22 @@ class CNNTransformer(nn.Module):
backbone: str,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
+ pool_kernel: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
self.vocab_size = vocab_size
self.backbone = configure_backbone(backbone, backbone_args)
+
+ if pool_kernel is not None:
+ self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+ else:
+ self.max_pool = None
+
self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.pos_dropout = nn.Dropout(p=dropout_rate)
self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
nn.init.normal_(self.character_embedding.weight, std=0.02)
@@ -98,18 +106,23 @@ class CNNTransformer(nn.Module):
# If batch dimension is missing, it needs to be added.
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
+
src = self.backbone(src)
+ if self.max_pool is not None:
+ src = self.max_pool(src)
+
if self.adaptive_pool is not None:
src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
src = src.squeeze(3)
else:
- src = rearrange(src, "b c h w -> b (w h) c")
+ src = rearrange(src, "b c h w -> b (h w) c")
b, t, _ = src.shape
src += self.src_position_embedding[:, :t]
+ src = self.pos_dropout(src)
return src
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index ffad792..2605731 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -1,4 +1,7 @@
"""Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
import Levenshtein as Lev
import torch
from torch import Tensor
@@ -32,22 +35,33 @@ def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
return acc
-def cer(outputs: Tensor, targets: Tensor) -> float:
+def cer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the character error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+ batch_size (Optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The cer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0
@@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float:
return lev_dist / len(decoded_predictions)
-def wer(outputs: Tensor, targets: Tensor) -> float:
+def wer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the Word error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+ batch_size (optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The wer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
new file mode 100644
index 0000000..fdd6662
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/__init__.py
@@ -0,0 +1,2 @@
+"""Transducer modules."""
+from .tds_conv import TDS2d
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
new file mode 100644
index 0000000..018caf2
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/tds_conv.py
@@ -0,0 +1,205 @@
+"""Time-Depth Separable Convolutions.
+
+References:
+ https://arxiv.org/abs/1904.02619
+ https://arxiv.org/pdf/2010.01003.pdf
+
+Code stolen from:
+ https://github.com/facebookresearch/gtn_applications
+
+
+"""
+from typing import List, Tuple
+
+from einops import rearrange
+import gtn
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class TDSBlock2d(nn.Module):
+ """Internal block of a 2D TDSC network."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ img_depth: int,
+ kernel_size: Tuple[int],
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.img_depth = img_depth
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+ self.fc_dim = in_channels * img_depth
+
+ # Network placeholders.
+ self.conv = None
+ self.mlp = None
+ self.instance_norm = None
+
+ self._build_block()
+
+ def _build_block(self) -> None:
+ # Convolutional block.
+ self.conv = nn.Sequential(
+ nn.Conv3d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
+ padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # MLP block.
+ self.mlp = nn.Sequential(
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # Instance norm.
+ self.instance_norm = nn.ModuleList(
+ [
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math: `(B, CD, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ B, CD, H, W = x.shape
+ C, D = self.in_channels, self.img_depth
+ residual = x
+ x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
+ x = self.conv(x)
+ x = rearrange(x, "b c d h w -> b (c d) h w")
+ x += residual
+
+ x = self.instance_norm[0](x)
+
+ x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
+ x + self.instance_norm[1](x)
+
+ # Output shape: [B, CD, H, W]
+ return x
+
+
+class TDS2d(nn.Module):
+ """TDS Netowrk.
+
+ Structure is the following:
+ Downsample layer -> TDS2d group -> ... -> Linear output layer
+
+
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ depth: int,
+ tds_groups: Tuple[int],
+ kernel_size: Tuple[int],
+ dropout_rate: float,
+ in_channels: int = 1,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.depth = depth
+ self.tds_groups = tds_groups
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+
+ self.tds = None
+ self.fc = None
+
+ def _build_network(self) -> None:
+
+ modules = []
+ stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
+ if self.input_dim % stride_h:
+ raise RuntimeError(
+ f"Image height not divisible by total stride {stride_h}."
+ )
+
+ for tds_group in self.tds_groups:
+ # Add downsample layer.
+ out_channels = self.depth * tds_group["channels"]
+ modules.extend(
+ [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=out_channels,
+ kernel_size=self.kernel_size,
+ padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ stride=tds_group["stride"],
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.InstanceNorm2d(out_channels, affine=True),
+ ]
+ )
+
+ for _ in range(tds_group["num_blocks"]):
+ modules.append(
+ TDSBlock2d(
+ tds_group["channels"],
+ self.depth,
+ self.kernel_size,
+ self.dropout_rate,
+ )
+ )
+
+ self.in_channels = out_channels
+
+ self.tds = nn.Sequential(*modules)
+ self.fc = nn.Linear(
+ self.in_channels * self.input_dim // stride_h, self.output_dim
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math: `(B, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ B, H, W = x.shape
+ x = rearrange(
+ x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
+ )
+ x = self.tds(x)
+
+ # x shape: [B, C, H, W]
+ x = rearrange(x, "b c h w -> b w (c h)")
+
+ return self.fc(x)
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index 711a952..131a6b4 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -65,13 +65,18 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
network_args = state_dict["network_args"]
weights = state_dict["model_state"]
+ freeze = False
+ if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ backbone_args.pop("freeze")
+ freeze = True
+ network_args = backbone_args
+
# Initializes the network with trained weights.
backbone = backbone_(**network_args)
backbone.load_state_dict(weights)
- if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ if freeze:
for params in backbone.parameters():
params.requires_grad = False
-
else:
backbone_ = getattr(network_module, backbone)
backbone = backbone_(**backbone_args)
diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..c673d96
--- /dev/null
+++ b/src/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,150 @@
+"""A VQ-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.util import configure_backbone
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class VQTransformer(nn.Module):
+ """VQ+Transfomer for image to character sequence prediction."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ adaptive_pool_dim: Tuple,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ max_len: int,
+ backbone: str,
+ backbone_args: Optional[Dict] = None,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+
+ # Configure vector quantized backbone.
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.conv = nn.Sequential(
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
+ nn.ReLU(inplace=True),
+ )
+
+ # Configure embeddings for Transformer network.
+ self.trg_pad_index = trg_pad_index
+ self.vocab_size = vocab_size
+ self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+ self.adaptive_pool = (
+ nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+ )
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
+ """Extracts image features with a backbone neural network.
+
+ It seem like the winning idea was to swap channels and width dimension and collapse
+ the height dimension. The transformer is learning like a baby with this implementation!!! :D
+ Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: The input src to the transformer and the vq loss.
+
+ """
+ # If batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src, vq_loss = self.backbone.encode(src)
+ # src = self.backbone.decoder.res_block(src)
+ src = self.conv(src)
+
+ if self.adaptive_pool is not None:
+ src = rearrange(src, "b c h w -> b w c h")
+ src = self.adaptive_pool(src)
+ src = src.squeeze(3)
+ else:
+ src = rearrange(src, "b c h w -> b (w h) c")
+
+ b, t, _ = src.shape
+
+ src += self.src_position_embedding[:, :t]
+
+ return src, vq_loss
+
+ def target_embedding(self, trg: Tensor) -> Tensor:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tensor: Encoded target tensor.
+
+ """
+ trg = self.character_embedding(trg.long())
+ trg = self.trg_position_encoding(trg)
+ return trg
+
+ def decode_image_features(
+ self, image_features: Tensor, trg: Optional[Tensor] = None
+ ) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.target_embedding(trg)
+ out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ image_features, vq_loss = self.extract_image_features(x)
+ logits = self.decode_image_features(image_features, trg)
+ return logits, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py
index e1f05fa..763953c 100644
--- a/src/text_recognizer/networks/vqvae/__init__.py
+++ b/src/text_recognizer/networks/vqvae/__init__.py
@@ -1 +1,5 @@
"""VQ-VAE module."""
+from .decoder import Decoder
+from .encoder import Encoder
+from .vector_quantizer import VectorQuantizer
+from .vqvae import VQVAE
diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py
new file mode 100644
index 0000000..8847aba
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/decoder.py
@@ -0,0 +1,133 @@
+"""CNN decoder for the VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class Decoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ upsampling: Optional[List[List[int]]] = None,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.upsampling = upsampling
+
+ self.res_block = nn.ModuleList([])
+ self.upsampling_block = nn.ModuleList([])
+
+ self.embedding_dim = embedding_dim
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.decoder = self._build_decoder(
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ )
+
+ def _build_decompression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ modules.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ activation,
+ )
+ )
+
+ if i < len(self.upsampling):
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ modules.extend(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
+ ),
+ nn.Tanh(),
+ )
+ )
+
+ return modules
+
+ def _build_decoder(
+ self,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+
+ self.res_block.append(
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ )
+
+ # Bottleneck module.
+ self.res_block.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[0], channels[0], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ # Decompression module
+ self.upsampling_block.extend(
+ self._build_decompression_block(
+ channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ self.res_block = nn.Sequential(*self.res_block)
+ self.upsampling_block = nn.Sequential(*self.upsampling_block)
+
+ return nn.Sequential(self.res_block, self.upsampling_block)
+
+ def forward(self, z_q: Tensor) -> Tensor:
+ """Reconstruct input from given codes."""
+ x_reconstruction = self.decoder(z_q)
+ return x_reconstruction
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
index 60c4c43..d3adac5 100644
--- a/src/text_recognizer/networks/vqvae/encoder.py
+++ b/src/text_recognizer/networks/vqvae/encoder.py
@@ -1,6 +1,5 @@
"""CNN encoder for the VQ-VAE."""
-
-from typing import List, Optional, Type
+from typing import List, Optional, Tuple, Type
import torch
from torch import nn
@@ -12,16 +11,12 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- activation,
+ nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
]
@@ -42,23 +37,111 @@ class Encoder(nn.Module):
self,
in_channels: int,
channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
beta: float = 0.25,
- activation: str = "elu",
+ activation: str = "leaky_relu",
dropout_rate: float = 0.0,
) -> None:
super().__init__()
- pass
- # if dropout_rate:
- # if activation == "selu":
- # dropout = nn.AlphaDropout(p=dropout_rate)
- # else:
- # dropout = nn.Dropout(p=dropout_rate)
- # else:
- # dropout = None
-
- def _build_encoder(self) -> nn.Sequential:
- # TODO: Continue to implement encoder.
- pass
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.beta = beta
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.encoder = self._build_encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
+ )
+
+ # Configure Vector Quantizer.
+ self.vector_quantizer = VectorQuantizer(
+ self.num_embeddings, self.embedding_dim, self.beta
+ )
+
+ def _build_compression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for out_channels, kernel_size, stride in configuration:
+ modules.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ ),
+ activation,
+ )
+ )
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ return modules
+
+ def _build_encoder(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+ encoder = nn.ModuleList([])
+
+ # compression module
+ encoder.extend(
+ self._build_compression_block(
+ in_channels, channels, kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ # Bottleneck module.
+ encoder.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[-1], channels[-1], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ encoder.append(
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ )
+
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input into a discrete representation."""
+ z_e = self.encoder(x)
+ z_q, vq_loss = self.vector_quantizer(z_e)
+ return z_q, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
index 25e5583..f92c7ee 100644
--- a/src/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -26,7 +26,7 @@ class VectorQuantizer(nn.Module):
self.embedding = nn.Embedding(self.K, self.D)
# Initialize the codebook.
- self.embedding.weight.uniform_(-1 / self.K, 1 / self.K)
+ nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
def discretization_bottleneck(self, latent: Tensor) -> Tensor:
"""Computes the code nearest to the latent representation.
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+ """Vector Quantized Variational AutoEncoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ upsampling: Optional[List[List[int]]] = None,
+ beta: float = 0.25,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ # configure encoder.
+ self.encoder = Encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ num_embeddings,
+ beta,
+ activation,
+ dropout_rate,
+ )
+
+ # Configure decoder.
+ channels.reverse()
+ kernel_sizes.reverse()
+ strides.reverse()
+ self.decoder = Decoder(
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ upsampling,
+ activation,
+ dropout_rate,
+ )
+
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input to a latent code."""
+ return self.encoder(x)
+
+ def decode(self, z_q: Tensor) -> Tensor:
+ """Reconstructs input from latent codes."""
+ return self.decoder(z_q)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Compresses and decompresses input."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ z_q, vq_loss = self.encode(x)
+ x_reconstruction = self.decode(z_q)
+ return x_reconstruction, vq_loss