summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/quantizer/__init__.py0
-rw-r--r--text_recognizer/networks/quantizer/codebook.py96
-rw-r--r--text_recognizer/networks/quantizer/kmeans.py32
-rw-r--r--text_recognizer/networks/quantizer/quantizer.py59
-rw-r--r--text_recognizer/networks/quantizer/utils.py26
-rw-r--r--text_recognizer/networks/vqvae/quantizer.py141
-rw-r--r--training/conf/experiment/vq_transformer_lines.yaml149
-rw-r--r--training/conf/network/quantizer.yaml12
8 files changed, 374 insertions, 141 deletions
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/networks/quantizer/__init__.py
diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py
new file mode 100644
index 0000000..cb9bc59
--- /dev/null
+++ b/text_recognizer/networks/quantizer/codebook.py
@@ -0,0 +1,96 @@
+"""Codebook module."""
+from typing import Tuple
+
+import attr
+from einops import rearrange
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.quantizer.kmeans import kmeans
+from text_recognizer.networks.quantizer.utils import (
+ ema_inplace,
+ norm,
+ sample_vectors,
+)
+
+
+@attr.s(eq=False)
+class CosineSimilarityCodebook(nn.Module):
+ """Cosine similarity codebook."""
+
+ dim: int = attr.ib()
+ codebook_size: int = attr.ib()
+ kmeans_init: bool = attr.ib(default=False)
+ kmeans_iters: int = attr.ib(default=10)
+ decay: float = attr.ib(default=0.8)
+ eps: float = attr.ib(default=1.0e-5)
+ threshold_dead: int = attr.ib(default=2)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
+ if not self.kmeans_init:
+ embeddings = norm(torch.randn(self.codebook_size, self.dim))
+ else:
+ embeddings = torch.zeros(self.codebook_size, self.dim)
+ self.register_buffer("initalized", Tensor([not self.kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
+ self.register_buffer("embeddings", embeddings)
+
+ def _initalize_embedding(self, data: Tensor) -> None:
+ embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embeddings.data.copy_(embeddings)
+ self.cluster_size.data.copy_(cluster_size)
+ self.initalized.data.copy_(Tensor([True]))
+
+ def _replace(self, samples: Tensor, mask: Tensor) -> None:
+ samples = norm(samples)
+ modified_codebook = torch.where(
+ mask[..., None],
+ sample_vectors(samples, self.codebook_size),
+ self.embeddings,
+ )
+ self.embeddings.data.copy_(modified_codebook)
+
+ def _replace_dead_codes(self, batch_samples: Tensor) -> None:
+ if self.threshold_dead == 0:
+ return
+ dead_codes = self.cluster_size < self.threshold_dead
+ if not torch.any(dead_codes):
+ return
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self._replace(batch_samples, mask=dead_codes)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Quantizes tensor."""
+ shape = x.shape
+ flatten = rearrange(x, "... d -> (...) d")
+ flatten = norm(flatten)
+
+ if not self.initalized:
+ self._initalize_embedding(flatten)
+
+ embeddings = norm(self.embeddings)
+ dist = flatten @ embeddings.t()
+ indices = dist.max(dim=-1).indices
+ one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
+ indices = indices.view(*shape[:-1])
+
+ quantized = F.embedding(indices, self.embeddings)
+
+ if self.training:
+ bins = one_hot.sum(0)
+ ema_inplace(self.cluster_size, bins, self.decay)
+ zero_mask = bins == 0
+ bins = bins.masked_fill(zero_mask, 1.0)
+
+ embed_sum = flatten.t() @ one_hot
+ embed_norm = (embed_sum / bins.unsqueeze(0)).t()
+ embed_norm = norm(embed_norm)
+ embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
+ ema_inplace(self.embeddings, embed_norm, self.decay)
+ self._replace_dead_codes(x)
+
+ return quantized, indices
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
new file mode 100644
index 0000000..a34c381
--- /dev/null
+++ b/text_recognizer/networks/quantizer/kmeans.py
@@ -0,0 +1,32 @@
+"""K-means clustering for embeddings."""
+from typing import Tuple
+
+from einops import repeat
+import torch
+from torch import Tensor
+
+from text_recognizer.networks.quantizer.utils import norm, sample_vectors
+
+
+def kmeans(
+ samples: Tensor, num_clusters: int, num_iters: int = 10
+) -> Tuple[Tensor, Tensor]:
+ """Compute k-means clusters."""
+ D = samples.shape[-1]
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ dists = samples @ means.t()
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
+ new_means /= bins_min_clamped[..., None]
+ new_means = norm(new_means)
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
new file mode 100644
index 0000000..3e8f0b2
--- /dev/null
+++ b/text_recognizer/networks/quantizer/quantizer.py
@@ -0,0 +1,59 @@
+"""Implementation of a Vector Quantized Variational AutoEncoder.
+
+Reference:
+https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
+"""
+from typing import Tuple, Type
+
+import attr
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+@attr.s(eq=False)
+class VectorQuantizer(nn.Module):
+ """Vector quantizer."""
+
+ input_dim: int = attr.ib()
+ codebook: Type[nn.Module] = attr.ib()
+ commitment: float = attr.ib(default=1.0)
+ project_in: nn.Linear = attr.ib(default=None, init=False)
+ project_out: nn.Linear = attr.ib(default=None, init=False)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
+ require_projection = self.codebook.dim != self.input_dim
+ self.project_in = (
+ nn.Linear(self.input_dim, self.codebook.dim)
+ if require_projection
+ else nn.Identity()
+ )
+ self.project_out = (
+ nn.Linear(self.codebook.dim, self.input_dim)
+ if require_projection
+ else nn.Identity()
+ )
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ """Quantizes latent vectors."""
+ H, W = x.shape[-2:]
+ x = rearrange(x, "b d h w -> b (h w) d")
+ x = self.project_in(x)
+
+ quantized, indices = self.codebook(x)
+
+ if self.training:
+ commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment
+ quantized = x + (quantized - x).detach()
+ else:
+ commitment_loss = torch.tensor([0.0]).type_as(x)
+
+ quantized = self.project_out(quantized)
+ quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
+
+ return quantized, indices, commitment_loss
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
new file mode 100644
index 0000000..0502d49
--- /dev/null
+++ b/text_recognizer/networks/quantizer/utils.py
@@ -0,0 +1,26 @@
+"""Helper functions for quantization."""
+from typing import Tuple
+
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+
+
+def sample_vectors(samples: Tensor, num: int) -> Tensor:
+ """Subsamples a set of vectors."""
+ B, device = samples.shape[0], samples.device
+ if B >= num:
+ indices = torch.randperm(B, device=device)[:num]
+ else:
+ indices = torch.randint(0, B, (num,), device=device)[:num]
+ return samples[indices]
+
+
+def norm(t: Tensor) -> Tensor:
+ """Applies L2-normalization."""
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
+ """Applies exponential moving average."""
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
deleted file mode 100644
index bba9b60..0000000
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ /dev/null
@@ -1,141 +0,0 @@
-"""Implementation of a Vector Quantized Variational AutoEncoder.
-
-Reference:
-https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
-"""
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-class EmbeddingEMA(nn.Module):
- """Embedding for Exponential Moving Average (EMA)."""
-
- def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
- super().__init__()
- weight = torch.zeros(num_embeddings, embedding_dim)
- nn.init.kaiming_uniform_(weight, nonlinearity="linear")
- self.register_buffer("weight", weight)
- self.register_buffer("cluster_size", torch.zeros(num_embeddings))
- self.register_buffer("weight_avg", weight.clone())
-
-
-class VectorQuantizer(nn.Module):
- """The codebook that contains quantized vectors."""
-
- def __init__(
- self, num_embeddings: int, embedding_dim: int, decay: float = 0.99
- ) -> None:
- super().__init__()
- self.num_embeddings = num_embeddings
- self.embedding_dim = embedding_dim
- self.decay = decay
- self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim)
-
- def _discretization_bottleneck(self, latent: Tensor) -> Tensor:
- """Computes the code nearest to the latent representation.
-
- First we compute the posterior categorical distribution, and then map
- the latent representation to the nearest element of the embedding.
-
- Args:
- latent (Tensor): The latent representation.
-
- Shape:
- - latent :math:`(B x H x W, D)`
-
- Returns:
- Tensor: The quantized embedding vector.
-
- """
- # Store latent shape.
- b, h, w, d = latent.shape
-
- # Flatten the latent representation to 2D.
- latent = rearrange(latent, "b h w d -> (b h w) d")
-
- # Compute the L2 distance between the latents and the embeddings.
- l2_distance = (
- torch.sum(latent ** 2, dim=1, keepdim=True)
- + torch.sum(self.embedding.weight ** 2, dim=1)
- - 2 * latent @ self.embedding.weight.t()
- ) # [BHW x K]
-
- # Find the embedding k nearest to each latent.
- encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1]
-
- # Convert to one-hot encodings, aka discrete bottleneck.
- one_hot_encoding = torch.zeros(
- encoding_indices.shape[0], self.num_embeddings, device=latent.device
- )
- one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K]
-
- # Embedding quantization.
- quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D]
- quantized_latent = rearrange(
- quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
- )
- if self.training:
- self._compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)
-
- return quantized_latent
-
- def _compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
- """Computes the EMA update to the codebook."""
- batch_cluster_size = one_hot_encoding.sum(axis=0)
- batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
- self.embedding.cluster_size.data.mul_(self.decay).add_(
- batch_cluster_size, alpha=1 - self.decay
- )
- self.embedding.weight_avg.data.mul_(self.decay).add_(
- batch_embedding_avg, alpha=1 - self.decay
- )
- new_embedding = self.embedding.weight_avg / (
- self.embedding.cluster_size + 1.0e-5
- ).unsqueeze(1)
- self.embedding.weight.data.copy_(new_embedding)
-
- def _commitment_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
- """Vector Quantization loss.
-
- The vector quantization algorithm allows us to create a codebook. The VQ
- algorithm works by moving the embedding vectors towards the encoder outputs.
-
- The embedding loss moves the embedding vector towards the encoder outputs. The
- .detach() works as the stop gradient (sg) described in the paper.
-
- Because the volume of the embedding space is dimensionless, it can arbitarily
- grow if the embeddings are not trained as fast as the encoder parameters. To
- mitigate this, a commitment loss is added in the second term which makes sure
- that the encoder commits to an embedding and that its output does not grow.
-
- Args:
- latent (Tensor): The encoder output.
- quantized_latent (Tensor): The quantized latent.
-
- Returns:
- Tensor: The combinded VQ loss.
-
- """
- loss = F.mse_loss(quantized_latent.detach(), latent)
- return loss
-
- def forward(self, latent: Tensor) -> Tensor:
- """Forward pass that returns the quantized vector and the vq loss."""
- # Rearrange latent representation s.t. the hidden dim is at the end.
- latent = rearrange(latent, "b d h w -> b h w d")
-
- # Maps latent to the nearest code in the codebook.
- quantized_latent = self._discretization_bottleneck(latent)
-
- loss = self._commitment_loss(latent, quantized_latent)
-
- # Add residue to the quantized latent.
- quantized_latent = latent + (quantized_latent - latent).detach()
-
- # Rearrange the quantized shape back to the original shape.
- quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
-
- return quantized_latent, loss
diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml
new file mode 100644
index 0000000..bbe1178
--- /dev/null
+++ b/training/conf/experiment/vq_transformer_lines.yaml
@@ -0,0 +1,149 @@
+# @package _global_
+
+defaults:
+ - override /mapping: null
+ - override /criterion: cross_entropy
+ - override /callbacks: htr
+ - override /datamodule: iam_lines
+ - override /network: null
+ - override /model: null
+ - override /lr_schedulers: null
+ - override /optimizers: null
+
+epochs: &epochs 512
+ignore_index: &ignore_index 3
+num_classes: &num_classes 57
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024], [1, 89]]
+
+criterion:
+ ignore_index: *ignore_index
+
+mapping: &mapping
+ mapping:
+ _target_: text_recognizer.data.mappings.emnist.EmnistMapping
+
+callbacks:
+ stochastic_weight_averaging:
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.75
+ swa_lrs: 1.0e-5
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+
+optimizers:
+ madgrad:
+ _target_: madgrad.MADGRAD
+ lr: 3.0e-4
+ momentum: 0.9
+ weight_decay: 0
+ eps: 1.0e-6
+ parameters: network
+
+lr_schedulers:
+ network:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+ T_max: *epochs
+ eta_min: 1.0e-5
+ last_epoch: -1
+ interval: epoch
+ monitor: val/loss
+
+datamodule:
+ batch_size: 16
+ num_workers: 12
+ train_fraction: 0.9
+ pin_memory: true
+ << : *mapping
+
+rotary_embedding: &rotary_embedding
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+ dim: 64
+
+attn: &attn
+ dim: &hidden_dim 512
+ num_heads: 4
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.4
+
+network:
+ _target_: text_recognizer.networks.vq_transformer.VqTransformer
+ input_dims: [1, 56, 1024]
+ hidden_dim: *hidden_dim
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ encoder:
+ _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+ arch: b1
+ stochastic_dropout_rate: 0.2
+ bn_momentum: 0.99
+ bn_eps: 1.0e-3
+ decoder:
+ depth: 6
+ _target_: text_recognizer.networks.transformer.layers.Decoder
+ self_attn:
+ _target_: text_recognizer.networks.transformer.attention.Attention
+ << : *attn
+ causal: true
+ << : *rotary_embedding
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.attention.Attention
+ << : *attn
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.norm.ScaleNorm
+ normalized_shape: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.mlp.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 4
+ glu: true
+ dropout_rate: *dropout_rate
+ pre_norm: true
+ pixel_pos_embedding:
+ _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+ dim: *hidden_dim
+ shape: [1, 32]
+ quantizer:
+ _target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
+ input_dim: 512
+ codebook:
+ _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
+ dim: 16
+ codebook_size: 4096
+ kmeans_init: true
+ kmeans_iters: 10
+ decay: 0.8
+ eps: 1.0e-5
+ threshold_dead: 2
+ commitment: 1.0
+
+model:
+ _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
+ << : *mapping
+ max_output_len: *max_output_len
+ start_token: <s>
+ end_token: <e>
+ pad_token: <p>
+
+trainer:
+ _target_: pytorch_lightning.Trainer
+ stochastic_weight_avg: true
+ auto_scale_batch_size: binsearch
+ auto_lr_find: false
+ gradient_clip_val: 0.5
+ fast_dev_run: false
+ gpus: 1
+ precision: 16
+ max_epochs: *epochs
+ terminate_on_nan: true
+ weights_summary: null
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ resume_from_checkpoint: null
+ accumulate_grad_batches: 1
+ overfit_batches: 0
diff --git a/training/conf/network/quantizer.yaml b/training/conf/network/quantizer.yaml
new file mode 100644
index 0000000..827a247
--- /dev/null
+++ b/training/conf/network/quantizer.yaml
@@ -0,0 +1,12 @@
+_target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
+input_dim: 192
+codebook:
+ _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
+ dim: 16
+ codebook_size: 2048
+ kmeans_init: true
+ kmeans_iters: 10
+ decay: 0.8
+ eps: 1.0e-5
+ threshold_dead: 2
+commitment: 1.0