summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/quantizer/__init__.py0
-rw-r--r--text_recognizer/networks/quantizer/codebook.py96
-rw-r--r--text_recognizer/networks/quantizer/kmeans.py32
-rw-r--r--text_recognizer/networks/quantizer/quantizer.py59
-rw-r--r--text_recognizer/networks/quantizer/utils.py26
-rw-r--r--text_recognizer/networks/vq_transformer.py84
-rw-r--r--text_recognizer/networks/vqvae/__init__.py1
-rw-r--r--text_recognizer/networks/vqvae/decoder.py93
-rw-r--r--text_recognizer/networks/vqvae/encoder.py85
-rw-r--r--text_recognizer/networks/vqvae/norm.py24
-rw-r--r--text_recognizer/networks/vqvae/residual.py54
-rw-r--r--text_recognizer/networks/vqvae/resize.py19
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py42
13 files changed, 0 insertions, 615 deletions
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/text_recognizer/networks/quantizer/__init__.py
+++ /dev/null
diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py
deleted file mode 100644
index cb9bc59..0000000
--- a/text_recognizer/networks/quantizer/codebook.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""Codebook module."""
-from typing import Tuple
-
-import attr
-from einops import rearrange
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.quantizer.kmeans import kmeans
-from text_recognizer.networks.quantizer.utils import (
- ema_inplace,
- norm,
- sample_vectors,
-)
-
-
-@attr.s(eq=False)
-class CosineSimilarityCodebook(nn.Module):
- """Cosine similarity codebook."""
-
- dim: int = attr.ib()
- codebook_size: int = attr.ib()
- kmeans_init: bool = attr.ib(default=False)
- kmeans_iters: int = attr.ib(default=10)
- decay: float = attr.ib(default=0.8)
- eps: float = attr.ib(default=1.0e-5)
- threshold_dead: int = attr.ib(default=2)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- if not self.kmeans_init:
- embeddings = norm(torch.randn(self.codebook_size, self.dim))
- else:
- embeddings = torch.zeros(self.codebook_size, self.dim)
- self.register_buffer("initalized", Tensor([not self.kmeans_init]))
- self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
- self.register_buffer("embeddings", embeddings)
-
- def _initalize_embedding(self, data: Tensor) -> None:
- embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
- self.embeddings.data.copy_(embeddings)
- self.cluster_size.data.copy_(cluster_size)
- self.initalized.data.copy_(Tensor([True]))
-
- def _replace(self, samples: Tensor, mask: Tensor) -> None:
- samples = norm(samples)
- modified_codebook = torch.where(
- mask[..., None],
- sample_vectors(samples, self.codebook_size),
- self.embeddings,
- )
- self.embeddings.data.copy_(modified_codebook)
-
- def _replace_dead_codes(self, batch_samples: Tensor) -> None:
- if self.threshold_dead == 0:
- return
- dead_codes = self.cluster_size < self.threshold_dead
- if not torch.any(dead_codes):
- return
- batch_samples = rearrange(batch_samples, "... d -> (...) d")
- self._replace(batch_samples, mask=dead_codes)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Quantizes tensor."""
- shape = x.shape
- flatten = rearrange(x, "... d -> (...) d")
- flatten = norm(flatten)
-
- if not self.initalized:
- self._initalize_embedding(flatten)
-
- embeddings = norm(self.embeddings)
- dist = flatten @ embeddings.t()
- indices = dist.max(dim=-1).indices
- one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
- indices = indices.view(*shape[:-1])
-
- quantized = F.embedding(indices, self.embeddings)
-
- if self.training:
- bins = one_hot.sum(0)
- ema_inplace(self.cluster_size, bins, self.decay)
- zero_mask = bins == 0
- bins = bins.masked_fill(zero_mask, 1.0)
-
- embed_sum = flatten.t() @ one_hot
- embed_norm = (embed_sum / bins.unsqueeze(0)).t()
- embed_norm = norm(embed_norm)
- embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
- ema_inplace(self.embeddings, embed_norm, self.decay)
- self._replace_dead_codes(x)
-
- return quantized, indices
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
deleted file mode 100644
index a34c381..0000000
--- a/text_recognizer/networks/quantizer/kmeans.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""K-means clustering for embeddings."""
-from typing import Tuple
-
-from einops import repeat
-import torch
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.utils import norm, sample_vectors
-
-
-def kmeans(
- samples: Tensor, num_clusters: int, num_iters: int = 10
-) -> Tuple[Tensor, Tensor]:
- """Compute k-means clusters."""
- D = samples.shape[-1]
-
- means = sample_vectors(samples, num_clusters)
-
- for _ in range(num_iters):
- dists = samples @ means.t()
- buckets = dists.max(dim=-1).indices
- bins = torch.bincount(buckets, minlength=num_clusters)
- zero_mask = bins == 0
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
-
- new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
- new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
- new_means /= bins_min_clamped[..., None]
- new_means = norm(new_means)
- means = torch.where(zero_mask[..., None], means, new_means)
-
- return means, bins
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
deleted file mode 100644
index 3e8f0b2..0000000
--- a/text_recognizer/networks/quantizer/quantizer.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Implementation of a Vector Quantized Variational AutoEncoder.
-
-Reference:
-https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
-"""
-from typing import Tuple, Type
-
-import attr
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-@attr.s(eq=False)
-class VectorQuantizer(nn.Module):
- """Vector quantizer."""
-
- input_dim: int = attr.ib()
- codebook: Type[nn.Module] = attr.ib()
- commitment: float = attr.ib(default=1.0)
- project_in: nn.Linear = attr.ib(default=None, init=False)
- project_out: nn.Linear = attr.ib(default=None, init=False)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- require_projection = self.codebook.dim != self.input_dim
- self.project_in = (
- nn.Linear(self.input_dim, self.codebook.dim)
- if require_projection
- else nn.Identity()
- )
- self.project_out = (
- nn.Linear(self.codebook.dim, self.input_dim)
- if require_projection
- else nn.Identity()
- )
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
- """Quantizes latent vectors."""
- H, W = x.shape[-2:]
- x = rearrange(x, "b d h w -> b (h w) d")
- x = self.project_in(x)
-
- quantized, indices = self.codebook(x)
-
- if self.training:
- commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment
- quantized = x + (quantized - x).detach()
- else:
- commitment_loss = torch.tensor([0.0]).type_as(x)
-
- quantized = self.project_out(quantized)
- quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
-
- return quantized, indices, commitment_loss
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
deleted file mode 100644
index 0502d49..0000000
--- a/text_recognizer/networks/quantizer/utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Helper functions for quantization."""
-from typing import Tuple
-
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-
-
-def sample_vectors(samples: Tensor, num: int) -> Tensor:
- """Subsamples a set of vectors."""
- B, device = samples.shape[0], samples.device
- if B >= num:
- indices = torch.randperm(B, device=device)[:num]
- else:
- indices = torch.randint(0, B, (num,), device=device)[:num]
- return samples[indices]
-
-
-def norm(t: Tensor) -> Tensor:
- """Applies L2-normalization."""
- return F.normalize(t, p=2, dim=-1)
-
-
-def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
- """Applies exponential moving average."""
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index a2bd81b..0000000
--- a/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""Vector quantized encoder, transformer decoder."""
-from typing import Optional, Tuple, Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conv_transformer import ConvTransformer
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-from text_recognizer.networks.transformer.layers import Decoder
-
-
-class VqTransformer(ConvTransformer):
- """Convolutional encoder and transformer decoder network."""
-
- def __init__(
- self,
- input_dims: Tuple[int, int, int],
- hidden_dim: int,
- num_classes: int,
- pad_index: Tensor,
- encoder: nn.Module,
- decoder: Decoder,
- pixel_pos_embedding: Type[nn.Module],
- quantizer: VectorQuantizer,
- token_pos_embedding: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__(
- input_dims=input_dims,
- hidden_dim=hidden_dim,
- num_classes=num_classes,
- pad_index=pad_index,
- encoder=encoder,
- decoder=decoder,
- pixel_pos_embedding=pixel_pos_embedding,
- token_pos_embedding=token_pos_embedding,
- )
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes an image into a discrete (VQ) latent representation.
-
- Args:
- x (Tensor): Image tensor.
-
- Shape:
- - x: :math: `(B, C, H, W)`
- - z: :math: `(B, Sx, E)`
-
- where Sx is the length of the flattened feature maps projected from
- the encoder. E latent dimension for each pixel in the projected
- feature maps.
-
- Returns:
- Tensor: A Latent embedding of the image.
- """
- z = self.encoder(x)
- z = self.conv(z)
- z, _, commitment_loss = self.quantizer(z)
- z = self.pixel_pos_embedding(z)
- z = z.flatten(start_dim=2)
-
- # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
- z = z.permute(0, 2, 1)
- return z, commitment_loss
-
- def forward(self, x: Tensor, context: Tensor) -> Tensor:
- """Encodes images into word piece logtis.
-
- Args:
- x (Tensor): Input image(s).
- context (Tensor): Target word embeddings.
-
- Shapes:
- - x: :math: `(B, C, H, W)`
- - context: :math: `(B, Sy, T)`
-
- where B is the batch size, C is the number of input channels, H is
- the image height and W is the image width.
-
- Returns:
- Tensor: Sequence of logits.
- """
- z, commitment_loss = self.encode(x)
- logits = self.decode(z, context)
- return logits, commitment_loss
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
deleted file mode 100644
index e1f05fa..0000000
--- a/text_recognizer/networks/vqvae/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""VQ-VAE module."""
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
deleted file mode 100644
index 7734a5a..0000000
--- a/text_recognizer/networks/vqvae/decoder.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""CNN decoder for the VQ-VAE."""
-from typing import Sequence
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Decoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- out_channels: int,
- hidden_dim: int,
- channels_multipliers: Sequence[int],
- dropout_rate: float,
- activation: str = "mish",
- use_norm: bool = False,
- num_residuals: int = 4,
- residual_channels: int = 32,
- ) -> None:
- super().__init__()
- self.out_channels = out_channels
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.activation = activation
- self.dropout_rate = dropout_rate
- self.use_norm = use_norm
- self.num_residuals = num_residuals
- self.residual_channels = residual_channels
- self.decoder = self._build_decompression_block()
-
- def _build_decompression_block(self,) -> nn.Sequential:
- decoder = []
- in_channels = self.hidden_dim * self.channels_multipliers[0]
- for _ in range(self.num_residuals):
- decoder += [
- Residual(
- in_channels=in_channels,
- residual_channels=self.residual_channels,
- use_norm=self.use_norm,
- activation=self.activation,
- ),
- ]
-
- activation_fn = activation_function(self.activation)
- out_channels_multipliers = self.channels_multipliers + (1,)
- num_blocks = len(self.channels_multipliers)
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * self.channels_multipliers[i]
- out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
- if self.use_norm:
- decoder += [
- Normalize(num_channels=in_channels,),
- ]
- decoder += [
- activation_fn,
- nn.ConvTranspose2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- ]
-
- if self.use_norm:
- decoder += [
- Normalize(
- num_channels=self.hidden_dim * out_channels_multipliers[-1],
- num_groups=self.hidden_dim * out_channels_multipliers[-1] // 4,
- ),
- ]
-
- decoder += [
- nn.Conv2d(
- in_channels=self.hidden_dim * out_channels_multipliers[-1],
- out_channels=self.out_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
- return nn.Sequential(*decoder)
-
- def forward(self, z_q: Tensor) -> Tensor:
- """Reconstruct input from given codes."""
- return self.decoder(z_q)
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
deleted file mode 100644
index 4761486..0000000
--- a/text_recognizer/networks/vqvae/encoder.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""CNN encoder for the VQ-VAE."""
-from typing import List, Tuple
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Encoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- in_channels: int,
- hidden_dim: int,
- channels_multipliers: List[int],
- dropout_rate: float,
- activation: str = "mish",
- use_norm: bool = False,
- num_residuals: int = 4,
- residual_channels: int = 32,
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.activation = activation
- self.dropout_rate = dropout_rate
- self.use_norm = use_norm
- self.num_residuals = num_residuals
- self.residual_channels = residual_channels
- self.encoder = self._build_compression_block()
-
- def _build_compression_block(self) -> nn.Sequential:
- """Builds encoder network."""
- num_blocks = len(self.channels_multipliers)
- channels_multipliers = (1,) + self.channels_multipliers
- activation_fn = activation_function(self.activation)
-
- encoder = [
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.hidden_dim,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * channels_multipliers[i]
- out_channels = self.hidden_dim * channels_multipliers[i + 1]
- if self.use_norm:
- encoder += [
- Normalize(num_channels=in_channels,),
- ]
- encoder += [
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- ]
-
- for _ in range(self.num_residuals):
- encoder += [
- Residual(
- in_channels=out_channels,
- residual_channels=self.residual_channels,
- use_norm=self.use_norm,
- activation=self.activation,
- )
- ]
-
- return nn.Sequential(*encoder)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes input into a discrete representation."""
- return self.encoder(x)
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
deleted file mode 100644
index d73f9f8..0000000
--- a/text_recognizer/networks/vqvae/norm.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""Normalizer block."""
-import attr
-from torch import nn, Tensor
-
-
-@attr.s(eq=False)
-class Normalize(nn.Module):
- num_channels: int = attr.ib()
- num_groups: int = attr.ib(default=32)
- norm: nn.GroupNorm = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- super().__init__()
- self.norm = nn.GroupNorm(
- num_groups=self.num_groups,
- num_channels=self.num_channels,
- eps=1.0e-6,
- affine=True,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies group normalization."""
- return self.norm(x)
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
deleted file mode 100644
index bdff9eb..0000000
--- a/text_recognizer/networks/vqvae/residual.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""Residual block."""
-import attr
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-@attr.s(eq=False)
-class Residual(nn.Module):
- in_channels: int = attr.ib()
- residual_channels: int = attr.ib()
- use_norm: bool = attr.ib(default=False)
- activation: str = attr.ib(default="relu")
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- super().__init__()
- self.block = self._build_res_block()
-
- def _build_res_block(self) -> nn.Sequential:
- """Build residual block."""
- block = []
- activation_fn = activation_function(activation=self.activation)
-
- if self.use_norm:
- block.append(Normalize(num_channels=self.in_channels))
-
- block += [
- activation_fn,
- nn.Conv2d(
- self.in_channels,
- self.residual_channels,
- kernel_size=3,
- padding=1,
- bias=False,
- ),
- ]
-
- if self.use_norm:
- block.append(Normalize(num_channels=self.residual_channels))
-
- block += [
- activation_fn,
- nn.Conv2d(
- self.residual_channels, self.in_channels, kernel_size=1, bias=False
- ),
- ]
- return nn.Sequential(*block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the residual forward pass."""
- return x + self.block(x)
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
deleted file mode 100644
index 8d67d02..0000000
--- a/text_recognizer/networks/vqvae/resize.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Up and down-sample with linear interpolation."""
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-
-class Upsample(nn.Module):
- """Upsamples by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies upsampling."""
- return F.interpolate(x, scale_factor=2.0, mode="nearest")
-
-
-class Downsample(nn.Module):
- """Downsampling by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies downsampling."""
- return F.avg_pool2d(x, kernel_size=2, stride=2)
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
deleted file mode 100644
index 5560e12..0000000
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""The VQ-VAE."""
-from typing import Tuple
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-
-
-class VQVAE(nn.Module):
- """Vector Quantized Variational AutoEncoder."""
-
- def __init__(
- self,
- encoder: nn.Module,
- decoder: nn.Module,
- quantizer: VectorQuantizer,
- ) -> None:
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tensor:
- """Encodes input to a latent code."""
- return self.encoder(x)
-
- def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
- """Quantizes the encoded latent vectors."""
- z_q, _, commitment_loss = self.quantizer(z_e)
- return z_q, commitment_loss
-
- def decode(self, z_q: Tensor) -> Tensor:
- """Reconstructs input from latent codes."""
- return self.decoder(z_q)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Compresses and decompresses input."""
- z_e = self.encode(x)
- z_q, commitment_loss = self.quantize(z_e)
- x_hat = self.decode(z_q)
- return x_hat, commitment_loss