summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/quantizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-27 18:17:16 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-27 18:17:16 +0200
commit68cda1e6abf76312eb13f45ca3c8565a0b00d745 (patch)
tree4a3a26acc70901675ac92b1f9e2b71685b107864 /text_recognizer/networks/quantizer
parentaeadb9d82c577879ab8110eb20a9a12d6ca6750c (diff)
Add quantizer
Diffstat (limited to 'text_recognizer/networks/quantizer')
-rw-r--r--text_recognizer/networks/quantizer/__init__.py2
-rw-r--r--text_recognizer/networks/quantizer/cosine_codebook.py104
-rw-r--r--text_recognizer/networks/quantizer/kmeans.py32
-rw-r--r--text_recognizer/networks/quantizer/quantizer.py69
-rw-r--r--text_recognizer/networks/quantizer/utils.py50
5 files changed, 257 insertions, 0 deletions
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
new file mode 100644
index 0000000..9c8685c
--- /dev/null
+++ b/text_recognizer/networks/quantizer/__init__.py
@@ -0,0 +1,2 @@
+from text_recognizer.networks.quantizer.cosine_codebook import CosineSimilarityCodebook
+from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
diff --git a/text_recognizer/networks/quantizer/cosine_codebook.py b/text_recognizer/networks/quantizer/cosine_codebook.py
new file mode 100644
index 0000000..3b6af0f
--- /dev/null
+++ b/text_recognizer/networks/quantizer/cosine_codebook.py
@@ -0,0 +1,104 @@
+"""Codebook module."""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.quantizer.kmeans import kmeans
+from text_recognizer.networks.quantizer.utils import (
+ ema_inplace,
+ norm,
+ sample_vectors,
+ gumbel_sample,
+)
+
+
+class CosineSimilarityCodebook(nn.Module):
+ """Cosine similarity codebook."""
+
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ kmeans_init: bool = False,
+ kmeans_iters: int = 10,
+ decay: float = 0.8,
+ eps: float = 1.0e-5,
+ threshold_dead: int = 2,
+ temperature: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.dim = dim
+ self.codebook_size = codebook_size
+ self.kmeans_init = kmeans_init
+ self.kmeans_iters = kmeans_iters
+ self.decay = decay
+ self.eps = eps
+ self.threshold_dead = threshold_dead
+ self.temperature = temperature
+
+ if not self.kmeans_init:
+ embeddings = norm(torch.randn(self.codebook_size, self.dim))
+ else:
+ embeddings = torch.zeros(self.codebook_size, self.dim)
+ self.register_buffer("initalized", Tensor([not self.kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
+ self.register_buffer("embeddings", embeddings)
+
+ def _initalize_embedding(self, data: Tensor) -> None:
+ embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embeddings.data.copy_(embeddings)
+ self.cluster_size.data.copy_(cluster_size)
+ self.initalized.data.copy_(Tensor([True]))
+
+ def _replace(self, samples: Tensor, mask: Tensor) -> None:
+ samples = norm(samples)
+ modified_codebook = torch.where(
+ mask[..., None],
+ sample_vectors(samples, self.codebook_size),
+ self.embeddings,
+ )
+ self.embeddings.data.copy_(modified_codebook)
+
+ def _replace_dead_codes(self, batch_samples: Tensor) -> None:
+ if self.threshold_dead == 0:
+ return
+ dead_codes = self.cluster_size < self.threshold_dead
+ if not torch.any(dead_codes):
+ return
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self._replace(batch_samples, mask=dead_codes)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Quantizes tensor."""
+ shape = x.shape
+ flatten = rearrange(x, "... d -> (...) d")
+ flatten = norm(flatten)
+
+ if not self.initalized:
+ self._initalize_embedding(flatten)
+
+ embeddings = norm(self.embeddings)
+ dist = flatten @ embeddings.t()
+ indices = gumbel_sample(dist, dim=-1, temperature=self.temperature)
+ one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
+ indices = indices.view(*shape[:-1])
+
+ quantized = F.embedding(indices, self.embeddings)
+
+ if self.training:
+ bins = one_hot.sum(0)
+ ema_inplace(self.cluster_size, bins, self.decay)
+ zero_mask = bins == 0
+ bins = bins.masked_fill(zero_mask, 1.0)
+
+ embed_sum = flatten.t() @ one_hot
+ embed_norm = (embed_sum / bins.unsqueeze(0)).t()
+ embed_norm = norm(embed_norm)
+ embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
+ ema_inplace(self.embeddings, embed_norm, self.decay)
+ self._replace_dead_codes(x)
+
+ return quantized, indices
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
new file mode 100644
index 0000000..a34c381
--- /dev/null
+++ b/text_recognizer/networks/quantizer/kmeans.py
@@ -0,0 +1,32 @@
+"""K-means clustering for embeddings."""
+from typing import Tuple
+
+from einops import repeat
+import torch
+from torch import Tensor
+
+from text_recognizer.networks.quantizer.utils import norm, sample_vectors
+
+
+def kmeans(
+ samples: Tensor, num_clusters: int, num_iters: int = 10
+) -> Tuple[Tensor, Tensor]:
+ """Compute k-means clusters."""
+ D = samples.shape[-1]
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ dists = samples @ means.t()
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
+ new_means /= bins_min_clamped[..., None]
+ new_means = norm(new_means)
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
new file mode 100644
index 0000000..2c07b79
--- /dev/null
+++ b/text_recognizer/networks/quantizer/quantizer.py
@@ -0,0 +1,69 @@
+from typing import Optional, Tuple, Type
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.quantizer.utils import orthgonal_loss_fn
+
+
+class VectorQuantizer(nn.Module):
+ """Vector quantizer."""
+
+ def __init__(
+ self,
+ input_dim: int,
+ codebook: Type[nn.Module],
+ commitment: float = 1.0,
+ ort_reg_weight: float = 0,
+ ort_reg_max_codes: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+ self.input_dim = input_dim
+ self.codebook = codebook
+ self.commitment = commitment
+ self.ort_reg_weight = ort_reg_weight
+ self.ort_reg_max_codes = ort_reg_max_codes
+ require_projection = self.codebook.dim != self.input_dim
+ self.project_in = (
+ nn.Linear(self.input_dim, self.codebook.dim)
+ if require_projection
+ else nn.Identity()
+ )
+ self.project_out = (
+ nn.Linear(self.codebook.dim, self.input_dim)
+ if require_projection
+ else nn.Identity()
+ )
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ """Quantizes latent vectors."""
+ H, W = x.shape[-2:]
+ device = x.device
+ x = rearrange(x, "b d h w -> b (h w) d")
+ x = self.project_in(x)
+
+ quantized, indices = self.codebook(x)
+
+ if self.training:
+ loss = F.mse_loss(quantized.detach(), x) * self.commitment
+ quantized = x + (quantized - x).detach()
+ if self.ort_reg_weight > 0:
+ codebook = self.codebook.embeddings
+ num_codes = codebook.shape[0]
+ if num_codes > self.ort_reg_max_codes:
+ rand_ids = torch.randperm(num_codes, device=device)[
+ : self.ort_reg_max_codes
+ ]
+ codebook = codebook[rand_ids]
+ orthgonal_loss = orthgonal_loss_fn(codebook)
+ loss += self.ort_reg_weight * orthgonal_loss
+ else:
+ loss = torch.tensor([0.0]).type_as(x)
+
+ quantized = self.project_out(quantized)
+ quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
+
+ return quantized, indices, loss
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
new file mode 100644
index 0000000..ec97949
--- /dev/null
+++ b/text_recognizer/networks/quantizer/utils.py
@@ -0,0 +1,50 @@
+"""Helper functions for quantization."""
+from typing import Tuple
+
+import torch
+from torch import einsum, Tensor
+import torch.nn.functional as F
+
+
+def sample_vectors(samples: Tensor, num: int) -> Tensor:
+ """Subsamples a set of vectors."""
+ B, device = samples.shape[0], samples.device
+ if B >= num:
+ indices = torch.randperm(B, device=device)[:num]
+ else:
+ indices = torch.randint(0, B, (num,), device=device)[:num]
+ return samples[indices]
+
+
+def norm(t: Tensor) -> Tensor:
+ """Applies L2-normalization."""
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
+ """Applies exponential moving average."""
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def log(t: Tensor, eps: float = 1e-20) -> Tensor:
+ return torch.log(t.clamp(min=eps))
+
+
+def gumbel_noise(t: Tensor) -> Tensor:
+ noise = torch.zeros_like(t).uniform_(0, 1)
+ return -log(-log(noise))
+
+
+def gumbel_sample(t: Tensor, temperature: float = 1.0, dim: int = -1) -> Tensor:
+ if temperature == 0:
+ return t.argmax(dim=dim)
+ return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim)
+
+
+def orthgonal_loss_fn(t: Tensor) -> Tensor:
+ # eq (2) from https://arxiv.org/abs/2112.00384
+ n = t.shape[0]
+ normed_codes = norm(t)
+ identity = torch.eye(n, device=t.device)
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)