Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r--  text_recognizer/networks/vqvae/__init__.py    1
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py    93
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py    85
-rw-r--r--  text_recognizer/networks/vqvae/norm.py       24
-rw-r--r--  text_recognizer/networks/vqvae/residual.py   54
-rw-r--r--  text_recognizer/networks/vqvae/resize.py     19
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py      42
7 files changed, 0 insertions, 318 deletions
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
deleted file mode 100644
index e1f05fa..0000000
--- a/text_recognizer/networks/vqvae/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""VQ-VAE module."""
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
deleted file mode 100644
index 7734a5a..0000000
--- a/text_recognizer/networks/vqvae/decoder.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""CNN decoder for the VQ-VAE."""
-from typing import Sequence
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Decoder(nn.Module):
- """A CNN encoder network."""
-
-    def __init__(
-        self,
-        out_channels: int,
-        hidden_dim: int,
-        channels_multipliers: Sequence[int],
-        dropout_rate: float,
-        activation: str = "mish",
-        use_norm: bool = False,
-        num_residuals: int = 4,
-        residual_channels: int = 32,
-    ) -> None:
-        super().__init__()
-        self.out_channels = out_channels
-        self.hidden_dim = hidden_dim
-        self.channels_multipliers = tuple(channels_multipliers)
-        self.activation = activation
-        self.dropout_rate = dropout_rate
-        self.use_norm = use_norm
-        self.num_residuals = num_residuals
-        self.residual_channels = residual_channels
-        self.decoder = self._build_decompression_block()
-
-    def _build_decompression_block(self) -> nn.Sequential:
-        decoder = []
-        in_channels = self.hidden_dim * self.channels_multipliers[0]
-        for _ in range(self.num_residuals):
-            decoder += [
-                Residual(
-                    in_channels=in_channels,
-                    residual_channels=self.residual_channels,
-                    use_norm=self.use_norm,
-                    activation=self.activation,
-                ),
-            ]
-
-        activation_fn = activation_function(self.activation)
-        out_channels_multipliers = self.channels_multipliers + (1,)
-        num_blocks = len(self.channels_multipliers)
-
-        for i in range(num_blocks):
-            in_channels = self.hidden_dim * self.channels_multipliers[i]
-            out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
-            if self.use_norm:
-                decoder += [
-                    Normalize(num_channels=in_channels),
-                ]
-            decoder += [
-                activation_fn,
-                nn.ConvTranspose2d(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    kernel_size=4,
-                    stride=2,
-                    padding=1,
-                ),
-            ]
-
-        if self.use_norm:
-            decoder += [
-                Normalize(
-                    num_channels=self.hidden_dim * out_channels_multipliers[-1],
-                    num_groups=self.hidden_dim * out_channels_multipliers[-1] // 4,
-                ),
-            ]
-
-        decoder += [
-            nn.Conv2d(
-                in_channels=self.hidden_dim * out_channels_multipliers[-1],
-                out_channels=self.out_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-            ),
-        ]
-        return nn.Sequential(*decoder)
-
-    def forward(self, z_q: Tensor) -> Tensor:
-        """Reconstruct input from given codes."""
-        return self.decoder(z_q)
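
Note: the removed decoder stacks residual blocks and then one (activation, ConvTranspose2d) pair per channel multiplier, each pair doubling the spatial resolution, followed by a final 3x3 convolution to the output channels. Below is a minimal sketch of just the upsampling path; the residual and norm blocks are omitted, the hyperparameter values are assumed for illustration (the real ones lived in the training configs), and nn.Mish stands in for activation_function("mish").

import torch
from torch import nn

# Assumed hyperparameters, for illustration only.
hidden_dim, channels_multipliers, out_channels = 32, (4, 2, 1), 1

mults = tuple(channels_multipliers) + (1,)
stages = []
for i in range(len(channels_multipliers)):
    stages += [
        nn.Mish(),
        # kernel 4 / stride 2 / padding 1 doubles H and W at every block.
        nn.ConvTranspose2d(hidden_dim * channels_multipliers[i],
                           hidden_dim * mults[i + 1],
                           kernel_size=4, stride=2, padding=1),
    ]
stages.append(nn.Conv2d(hidden_dim * mults[-1], out_channels, kernel_size=3, padding=1))
decoder = nn.Sequential(*stages)

z_q = torch.randn(1, hidden_dim * channels_multipliers[0], 8, 8)
print(decoder(z_q).shape)  # torch.Size([1, 1, 64, 64]) -- 8x8 latents, 2**3 upsampling
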
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
deleted file mode 100644
index 4761486..0000000
--- a/text_recognizer/networks/vqvae/encoder.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""CNN encoder for the VQ-VAE."""
-from typing import List
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Encoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- in_channels: int,
- hidden_dim: int,
- channels_multipliers: List[int],
- dropout_rate: float,
- activation: str = "mish",
- use_norm: bool = False,
- num_residuals: int = 4,
- residual_channels: int = 32,
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.activation = activation
- self.dropout_rate = dropout_rate
- self.use_norm = use_norm
- self.num_residuals = num_residuals
- self.residual_channels = residual_channels
- self.encoder = self._build_compression_block()
-
- def _build_compression_block(self) -> nn.Sequential:
- """Builds encoder network."""
- num_blocks = len(self.channels_multipliers)
- channels_multipliers = (1,) + self.channels_multipliers
- activation_fn = activation_function(self.activation)
-
- encoder = [
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.hidden_dim,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * channels_multipliers[i]
- out_channels = self.hidden_dim * channels_multipliers[i + 1]
- if self.use_norm:
- encoder += [
- Normalize(num_channels=in_channels,),
- ]
- encoder += [
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- ]
-
- for _ in range(self.num_residuals):
- encoder += [
- Residual(
- in_channels=out_channels,
- residual_channels=self.residual_channels,
- use_norm=self.use_norm,
- activation=self.activation,
- )
- ]
-
- return nn.Sequential(*encoder)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Encodes input into a latent representation."""
-        return self.encoder(x)
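
Note: each stride-2 convolution in the loop above halves the feature-map height and width, so N channel multipliers give 2**N overall downsampling. A standalone sketch of that compression path, with the residual and norm blocks omitted, nn.Mish standing in for activation_function("mish"), and the sizes assumed for illustration:

import torch
from torch import nn

# Assumed hyperparameters, for illustration only.
in_channels, hidden_dim, channels_multipliers = 1, 32, (1, 2, 4)

mults = (1,) + tuple(channels_multipliers)
layers = [nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)]
for i in range(len(channels_multipliers)):
    layers += [
        nn.Mish(),
        # kernel 4 / stride 2 / padding 1 halves H and W at every block.
        nn.Conv2d(hidden_dim * mults[i], hidden_dim * mults[i + 1],
                  kernel_size=4, stride=2, padding=1),
    ]
encoder = nn.Sequential(*layers)

x = torch.randn(1, in_channels, 64, 64)
print(encoder(x).shape)  # torch.Size([1, 128, 8, 8]) -- 2**3 spatial downsampling
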
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
deleted file mode 100644
index d73f9f8..0000000
--- a/text_recognizer/networks/vqvae/norm.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""Normalizer block."""
-import attr
-from torch import nn, Tensor
-
-
-@attr.s(eq=False)
-class Normalize(nn.Module):
-    num_channels: int = attr.ib()
-    num_groups: int = attr.ib(default=32)
-    norm: nn.GroupNorm = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
-        super().__init__()
-        self.norm = nn.GroupNorm(
-            num_groups=self.num_groups,
-            num_channels=self.num_channels,
-            eps=1.0e-6,
-            affine=True,
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Applies group normalization."""
-        return self.norm(x)
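
Note: this wrapper is a thin layer over torch.nn.GroupNorm; with the default of 32 groups the channel count must be divisible by 32 (the decoder's final Normalize above overrides it with num_groups = channels // 4). A quick check with assumed sizes:

import torch
from torch import nn

# Same settings as the deleted block: eps=1e-6, learnable affine parameters.
# 128 channels is an assumed example; num_channels must be divisible by num_groups.
norm = nn.GroupNorm(num_groups=32, num_channels=128, eps=1.0e-6, affine=True)
x = torch.randn(4, 128, 8, 8)
print(norm(x).shape)  # torch.Size([4, 128, 8, 8]) -- normalization keeps the shape
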
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
deleted file mode 100644
index bdff9eb..0000000
--- a/text_recognizer/networks/vqvae/residual.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""Residual block."""
-import attr
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-@attr.s(eq=False)
-class Residual(nn.Module):
-    in_channels: int = attr.ib()
-    residual_channels: int = attr.ib()
-    use_norm: bool = attr.ib(default=False)
-    activation: str = attr.ib(default="relu")
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
-        super().__init__()
-        self.block = self._build_res_block()
-
-    def _build_res_block(self) -> nn.Sequential:
-        """Build residual block."""
-        block = []
-        activation_fn = activation_function(activation=self.activation)
-
-        if self.use_norm:
-            block.append(Normalize(num_channels=self.in_channels))
-
-        block += [
-            activation_fn,
-            nn.Conv2d(
-                self.in_channels,
-                self.residual_channels,
-                kernel_size=3,
-                padding=1,
-                bias=False,
-            ),
-        ]
-
-        if self.use_norm:
-            block.append(Normalize(num_channels=self.residual_channels))
-
-        block += [
-            activation_fn,
-            nn.Conv2d(
-                self.residual_channels, self.in_channels, kernel_size=1, bias=False
-            ),
-        ]
-        return nn.Sequential(*block)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Apply the residual forward pass."""
-        return x + self.block(x)
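
Note: the block is a pre-activation bottleneck: activation, 3x3 convolution down to residual_channels, activation, 1x1 convolution back up to in_channels, with the result added to the input. A norm-free equivalent with assumed channel sizes (the deleted code's default activation is "relu"):

import torch
from torch import nn

in_channels, residual_channels = 128, 32  # assumed sizes for illustration
block = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(in_channels, residual_channels, kernel_size=3, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(residual_channels, in_channels, kernel_size=1, bias=False),
)
x = torch.randn(1, in_channels, 8, 8)
out = x + block(x)   # the skip connection keeps channels and resolution unchanged
print(out.shape)     # torch.Size([1, 128, 8, 8])
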
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
deleted file mode 100644
index 8d67d02..0000000
--- a/text_recognizer/networks/vqvae/resize.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Up and down-sample with linear interpolation."""
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-
-class Upsample(nn.Module):
- """Upsamples by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies upsampling."""
- return F.interpolate(x, scale_factor=2.0, mode="nearest")
-
-
-class Downsample(nn.Module):
- """Downsampling by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies downsampling."""
- return F.avg_pool2d(x, kernel_size=2, stride=2)
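
Note: neither helper has learnable parameters; Upsample repeats pixels with nearest-neighbour interpolation and Downsample averages 2x2 windows, and neither class is referenced by the encoder or decoder in this diff. Their effect on shapes, as a quick sketch:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16, 16)
up = F.interpolate(x, scale_factor=2.0, mode="nearest")  # (1, 3, 32, 32)
down = F.avg_pool2d(up, kernel_size=2, stride=2)         # back to (1, 3, 16, 16)
print(up.shape, down.shape)
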
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
deleted file mode 100644
index 5560e12..0000000
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""The VQ-VAE."""
-from typing import Tuple
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-
-
-class VQVAE(nn.Module):
- """Vector Quantized Variational AutoEncoder."""
-
- def __init__(
- self,
- encoder: nn.Module,
- decoder: nn.Module,
- quantizer: VectorQuantizer,
- ) -> None:
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tensor:
- """Encodes input to a latent code."""
- return self.encoder(x)
-
- def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
- """Quantizes the encoded latent vectors."""
- z_q, _, commitment_loss = self.quantizer(z_e)
- return z_q, commitment_loss
-
- def decode(self, z_q: Tensor) -> Tensor:
- """Reconstructs input from latent codes."""
- return self.decoder(z_q)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Compresses and decompresses input."""
- z_e = self.encode(x)
- z_q, commitment_loss = self.quantize(z_e)
- x_hat = self.decode(z_q)
- return x_hat, commitment_loss
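
Note: VQVAE.forward only chains the three parts: encode to a continuous latent z_e, quantize it to z_q plus a commitment loss, and decode z_q back to an image. A toy sketch of that flow with hypothetical stand-ins (the real encoder, decoder and VectorQuantizer came from text_recognizer.networks and are not reproduced here):

import torch
from torch import nn

class ToyQuantizer(nn.Module):
    """Mimics the (z_q, codes, commitment_loss) return of the real quantizer."""
    def forward(self, z_e):
        z_q = z_e.round()  # stand-in for the codebook lookup
        commitment_loss = torch.mean((z_q.detach() - z_e) ** 2)
        return z_q, None, commitment_loss

encoder = nn.Conv2d(1, 8, kernel_size=4, stride=2, padding=1)           # 32x32 -> 16x16
decoder = nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=1)  # 16x16 -> 32x32
quantizer = ToyQuantizer()

x = torch.randn(2, 1, 32, 32)
z_e = encoder(x)                            # encode
z_q, _, commitment_loss = quantizer(z_e)    # quantize
x_hat = decoder(z_q)                        # decode
print(x_hat.shape, commitment_loss.item())  # reconstruction plus auxiliary loss
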