summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/quantizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-17 22:42:58 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-17 22:42:58 +0100
commit2417288c9fe96264da708ce8d13ac7bc2faf83e3 (patch)
treea74fb7a8502b1642b71240608706efda14dee3f9 /text_recognizer/networks/quantizer
parent2cb2c5b38f0711267fecfe9c5e10940f4b4f79fc (diff)
Add new quantizer
Diffstat (limited to 'text_recognizer/networks/quantizer')
-rw-r--r--text_recognizer/networks/quantizer/__init__.py0
-rw-r--r--text_recognizer/networks/quantizer/codebook.py96
-rw-r--r--text_recognizer/networks/quantizer/kmeans.py32
-rw-r--r--text_recognizer/networks/quantizer/quantizer.py59
-rw-r--r--text_recognizer/networks/quantizer/utils.py26
5 files changed, 213 insertions, 0 deletions
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/networks/quantizer/__init__.py
diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py
new file mode 100644
index 0000000..cb9bc59
--- /dev/null
+++ b/text_recognizer/networks/quantizer/codebook.py
@@ -0,0 +1,96 @@
+"""Codebook module."""
+from typing import Tuple
+
+import attr
+from einops import rearrange
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.quantizer.kmeans import kmeans
+from text_recognizer.networks.quantizer.utils import (
+ ema_inplace,
+ norm,
+ sample_vectors,
+)
+
+
+@attr.s(eq=False)
+class CosineSimilarityCodebook(nn.Module):
+ """Cosine similarity codebook."""
+
+ dim: int = attr.ib()
+ codebook_size: int = attr.ib()
+ kmeans_init: bool = attr.ib(default=False)
+ kmeans_iters: int = attr.ib(default=10)
+ decay: float = attr.ib(default=0.8)
+ eps: float = attr.ib(default=1.0e-5)
+ threshold_dead: int = attr.ib(default=2)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
+ if not self.kmeans_init:
+ embeddings = norm(torch.randn(self.codebook_size, self.dim))
+ else:
+ embeddings = torch.zeros(self.codebook_size, self.dim)
+ self.register_buffer("initalized", Tensor([not self.kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
+ self.register_buffer("embeddings", embeddings)
+
+ def _initalize_embedding(self, data: Tensor) -> None:
+ embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embeddings.data.copy_(embeddings)
+ self.cluster_size.data.copy_(cluster_size)
+ self.initalized.data.copy_(Tensor([True]))
+
+ def _replace(self, samples: Tensor, mask: Tensor) -> None:
+ samples = norm(samples)
+ modified_codebook = torch.where(
+ mask[..., None],
+ sample_vectors(samples, self.codebook_size),
+ self.embeddings,
+ )
+ self.embeddings.data.copy_(modified_codebook)
+
+ def _replace_dead_codes(self, batch_samples: Tensor) -> None:
+ if self.threshold_dead == 0:
+ return
+ dead_codes = self.cluster_size < self.threshold_dead
+ if not torch.any(dead_codes):
+ return
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self._replace(batch_samples, mask=dead_codes)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Quantizes tensor."""
+ shape = x.shape
+ flatten = rearrange(x, "... d -> (...) d")
+ flatten = norm(flatten)
+
+ if not self.initalized:
+ self._initalize_embedding(flatten)
+
+ embeddings = norm(self.embeddings)
+ dist = flatten @ embeddings.t()
+ indices = dist.max(dim=-1).indices
+ one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
+ indices = indices.view(*shape[:-1])
+
+ quantized = F.embedding(indices, self.embeddings)
+
+ if self.training:
+ bins = one_hot.sum(0)
+ ema_inplace(self.cluster_size, bins, self.decay)
+ zero_mask = bins == 0
+ bins = bins.masked_fill(zero_mask, 1.0)
+
+ embed_sum = flatten.t() @ one_hot
+ embed_norm = (embed_sum / bins.unsqueeze(0)).t()
+ embed_norm = norm(embed_norm)
+ embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
+ ema_inplace(self.embeddings, embed_norm, self.decay)
+ self._replace_dead_codes(x)
+
+ return quantized, indices
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
new file mode 100644
index 0000000..a34c381
--- /dev/null
+++ b/text_recognizer/networks/quantizer/kmeans.py
@@ -0,0 +1,32 @@
+"""K-means clustering for embeddings."""
+from typing import Tuple
+
+from einops import repeat
+import torch
+from torch import Tensor
+
+from text_recognizer.networks.quantizer.utils import norm, sample_vectors
+
+
+def kmeans(
+ samples: Tensor, num_clusters: int, num_iters: int = 10
+) -> Tuple[Tensor, Tensor]:
+ """Compute k-means clusters."""
+ D = samples.shape[-1]
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ dists = samples @ means.t()
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
+ new_means /= bins_min_clamped[..., None]
+ new_means = norm(new_means)
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
new file mode 100644
index 0000000..3e8f0b2
--- /dev/null
+++ b/text_recognizer/networks/quantizer/quantizer.py
@@ -0,0 +1,59 @@
+"""Implementation of a Vector Quantized Variational AutoEncoder.
+
+Reference:
+https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
+"""
+from typing import Tuple, Type
+
+import attr
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+@attr.s(eq=False)
+class VectorQuantizer(nn.Module):
+ """Vector quantizer."""
+
+ input_dim: int = attr.ib()
+ codebook: Type[nn.Module] = attr.ib()
+ commitment: float = attr.ib(default=1.0)
+ project_in: nn.Linear = attr.ib(default=None, init=False)
+ project_out: nn.Linear = attr.ib(default=None, init=False)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
+ require_projection = self.codebook.dim != self.input_dim
+ self.project_in = (
+ nn.Linear(self.input_dim, self.codebook.dim)
+ if require_projection
+ else nn.Identity()
+ )
+ self.project_out = (
+ nn.Linear(self.codebook.dim, self.input_dim)
+ if require_projection
+ else nn.Identity()
+ )
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ """Quantizes latent vectors."""
+ H, W = x.shape[-2:]
+ x = rearrange(x, "b d h w -> b (h w) d")
+ x = self.project_in(x)
+
+ quantized, indices = self.codebook(x)
+
+ if self.training:
+ commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment
+ quantized = x + (quantized - x).detach()
+ else:
+ commitment_loss = torch.tensor([0.0]).type_as(x)
+
+ quantized = self.project_out(quantized)
+ quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
+
+ return quantized, indices, commitment_loss
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
new file mode 100644
index 0000000..0502d49
--- /dev/null
+++ b/text_recognizer/networks/quantizer/utils.py
@@ -0,0 +1,26 @@
+"""Helper functions for quantization."""
+from typing import Tuple
+
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+
+
+def sample_vectors(samples: Tensor, num: int) -> Tensor:
+ """Subsamples a set of vectors."""
+ B, device = samples.shape[0], samples.device
+ if B >= num:
+ indices = torch.randperm(B, device=device)[:num]
+ else:
+ indices = torch.randint(0, B, (num,), device=device)[:num]
+ return samples[indices]
+
+
+def norm(t: Tensor) -> Tensor:
+ """Applies L2-normalization."""
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
+ """Applies exponential moving average."""
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))