From 68cda1e6abf76312eb13f45ca3c8565a0b00d745 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 27 Jun 2022 18:17:16 +0200
Subject: Add quantizer

---
 text_recognizer/networks/quantizer/__init__.py     |   2 +
 .../networks/quantizer/cosine_codebook.py          | 104 +++++++++++++++++++++
 text_recognizer/networks/quantizer/kmeans.py       |  32 +++++++
 text_recognizer/networks/quantizer/quantizer.py    |  69 ++++++++++++++
 text_recognizer/networks/quantizer/utils.py        |  50 ++++++++++
 5 files changed, 257 insertions(+)
 create mode 100644 text_recognizer/networks/quantizer/__init__.py
 create mode 100644 text_recognizer/networks/quantizer/cosine_codebook.py
 create mode 100644 text_recognizer/networks/quantizer/kmeans.py
 create mode 100644 text_recognizer/networks/quantizer/quantizer.py
 create mode 100644 text_recognizer/networks/quantizer/utils.py

diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
new file mode 100644
index 0000000..9c8685c
--- /dev/null
+++ b/text_recognizer/networks/quantizer/__init__.py
@@ -0,0 +1,2 @@
+from text_recognizer.networks.quantizer.cosine_codebook import CosineSimilarityCodebook
+from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
diff --git a/text_recognizer/networks/quantizer/cosine_codebook.py b/text_recognizer/networks/quantizer/cosine_codebook.py
new file mode 100644
index 0000000..3b6af0f
--- /dev/null
+++ b/text_recognizer/networks/quantizer/cosine_codebook.py
@@ -0,0 +1,104 @@
+"""Codebook module."""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.quantizer.kmeans import kmeans
+from text_recognizer.networks.quantizer.utils import (
+    ema_inplace,
+    norm,
+    sample_vectors,
+    gumbel_sample,
+)
+
+
+class CosineSimilarityCodebook(nn.Module):
+    """Cosine similarity codebook."""
+
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        kmeans_init: bool = False,
+        kmeans_iters: int = 10,
+        decay: float = 0.8,
+        eps: float = 1.0e-5,
+        threshold_dead: int = 2,
+        temperature: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+        self.codebook_size = codebook_size
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.decay = decay
+        self.eps = eps
+        self.threshold_dead = threshold_dead
+        self.temperature = temperature
+
+        if not self.kmeans_init:
+            embeddings = norm(torch.randn(self.codebook_size, self.dim))
+        else:
+            embeddings = torch.zeros(self.codebook_size, self.dim)
+        self.register_buffer("initalized", Tensor([not self.kmeans_init]))
+        self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
+        self.register_buffer("embeddings", embeddings)
+
+    def _initialize_embedding(self, data: Tensor) -> None:
+        """Initializes the codebook from k-means clusters of the first batch."""
+        embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        self.embeddings.data.copy_(embeddings)
+        self.cluster_size.data.copy_(cluster_size)
+        self.initialized.data.copy_(Tensor([True]))
+
+    def _replace(self, samples: Tensor, mask: Tensor) -> None:
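+        """Replaces masked codes with vectors sampled from the batch."""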
+        samples = norm(samples)
+        modified_codebook = torch.where(
+            mask[..., None],
+            sample_vectors(samples, self.codebook_size),
+            self.embeddings,
+        )
+        self.embeddings.data.copy_(modified_codebook)
+
+    def _replace_dead_codes(self, batch_samples: Tensor) -> None:
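+        """Resamples codes whose cluster size fell below the dead threshold."""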
+        if self.threshold_dead == 0:
+            return
+        dead_codes = self.cluster_size < self.threshold_dead
+        if not torch.any(dead_codes):
+            return
+        batch_samples = rearrange(batch_samples, "... d -> (...) d")
+        self._replace(batch_samples, mask=dead_codes)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Quantizes tensor."""
+        shape = x.shape
+        flatten = rearrange(x, "... d -> (...) d")
+        flatten = norm(flatten)
+
+        if not self.initialized:
+            self._initialize_embedding(flatten)
+
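+        # The flattened inputs and the codebook are both L2-normalized, so the
+        # matrix product below is the cosine similarity to every codebook entry.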
+        embeddings = norm(self.embeddings)
+        dist = flatten @ embeddings.t()
+        indices = gumbel_sample(dist, dim=-1, temperature=self.temperature)
+        one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
+        indices = indices.view(*shape[:-1])
+
+        quantized = F.embedding(indices, self.embeddings)
+
+        if self.training:
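+            # EMA update of cluster sizes and codebook entries; codes that got no
+            # assignments this batch keep their (normalized) previous embedding
+            # as the update target.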
+            bins = one_hot.sum(0)
+            ema_inplace(self.cluster_size, bins, self.decay)
+            zero_mask = bins == 0
+            bins = bins.masked_fill(zero_mask, 1.0)
+
+            embed_sum = flatten.t() @ one_hot
+            embed_norm = (embed_sum / bins.unsqueeze(0)).t()
+            embed_norm = norm(embed_norm)
+            embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
+            ema_inplace(self.embeddings, embed_norm, self.decay)
+            self._replace_dead_codes(x)
+
+        return quantized, indices
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
new file mode 100644
index 0000000..a34c381
--- /dev/null
+++ b/text_recognizer/networks/quantizer/kmeans.py
@@ -0,0 +1,32 @@
+"""K-means clustering for embeddings."""
+from typing import Tuple
+
+from einops import repeat
+import torch
+from torch import Tensor
+
+from text_recognizer.networks.quantizer.utils import norm, sample_vectors
+
+
+def kmeans(
+    samples: Tensor, num_clusters: int, num_iters: int = 10
+) -> Tuple[Tensor, Tensor]:
+    """Compute k-means clusters."""
+    D = samples.shape[-1]
+
+    means = sample_vectors(samples, num_clusters)
+
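+    # Alternate between assigning samples to the closest mean (dot-product
+    # similarity; callers pass L2-normalized vectors) and recomputing the means.
+    # Empty clusters keep their previous mean.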
+    for _ in range(num_iters):
+        dists = samples @ means.t()
+        buckets = dists.max(dim=-1).indices
+        bins = torch.bincount(buckets, minlength=num_clusters)
+        zero_mask = bins == 0
+        bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
+        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
+        new_means /= bins_min_clamped[..., None]
+        new_means = norm(new_means)
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return means, bins
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
new file mode 100644
index 0000000..2c07b79
--- /dev/null
+++ b/text_recognizer/networks/quantizer/quantizer.py
@@ -0,0 +1,69 @@
+"""Vector quantizer module."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.quantizer.utils import orthogonal_loss_fn
+
+
+class VectorQuantizer(nn.Module):
+    """Vector quantizer."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        codebook: nn.Module,
+        commitment: float = 1.0,
+        ort_reg_weight: float = 0,
+        ort_reg_max_codes: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.codebook = codebook
+        self.commitment = commitment
+        self.ort_reg_weight = ort_reg_weight
+        self.ort_reg_max_codes = ort_reg_max_codes
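+        # Project in/out when the codebook dimension differs from the input dim.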
+        require_projection = self.codebook.dim != self.input_dim
+        self.project_in = (
+            nn.Linear(self.input_dim, self.codebook.dim)
+            if require_projection
+            else nn.Identity()
+        )
+        self.project_out = (
+            nn.Linear(self.codebook.dim, self.input_dim)
+            if require_projection
+            else nn.Identity()
+        )
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """Quantizes latent vectors."""
+        H, W = x.shape[-2:]
+        device = x.device
+        x = rearrange(x, "b d h w -> b (h w) d")
+        x = self.project_in(x)
+
+        quantized, indices = self.codebook(x)
+
+        if self.training:
+            loss = F.mse_loss(quantized.detach(), x) * self.commitment
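+            # Straight-through estimator: the forward pass uses the quantized
+            # values; gradients flow back to x as if quantization were the identity.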
+            quantized = x + (quantized - x).detach()
+            if self.ort_reg_weight > 0:
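+                # Orthogonal regularization over the codebook, optionally
+                # restricted to a random subset of at most ort_reg_max_codes codes.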
+                codebook = self.codebook.embeddings
+                num_codes = codebook.shape[0]
+                if (
+                    self.ort_reg_max_codes is not None
+                    and num_codes > self.ort_reg_max_codes
+                ):
+                    rand_ids = torch.randperm(num_codes, device=device)[
+                        : self.ort_reg_max_codes
+                    ]
+                    codebook = codebook[rand_ids]
+                orthogonal_loss = orthogonal_loss_fn(codebook)
+                loss += self.ort_reg_weight * orthogonal_loss
+        else:
+            loss = torch.tensor([0.0]).type_as(x)
+
+        quantized = self.project_out(quantized)
+        quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
+
+        return quantized, indices, loss
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
new file mode 100644
index 0000000..ec97949
--- /dev/null
+++ b/text_recognizer/networks/quantizer/utils.py
@@ -0,0 +1,50 @@
+"""Helper functions for quantization."""
+from typing import Tuple
+
+import torch
+from torch import einsum, Tensor
+import torch.nn.functional as F
+
+
+def sample_vectors(samples: Tensor, num: int) -> Tensor:
+    """Subsamples a set of vectors."""
+    B, device = samples.shape[0], samples.device
+    if B >= num:
+        indices = torch.randperm(B, device=device)[:num]
+    else:
+        indices = torch.randint(0, B, (num,), device=device)
+    return samples[indices]
+
+
+def norm(t: Tensor) -> Tensor:
+    """Applies L2-normalization."""
+    return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
+    """Applies exponential moving average."""
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def log(t: Tensor, eps: float = 1e-20) -> Tensor:
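+    """Numerically stable log; clamps the input away from zero."""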
+    return torch.log(t.clamp(min=eps))
+
+
+def gumbel_noise(t: Tensor) -> Tensor:
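+    """Samples standard Gumbel noise with the same shape as t."""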
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+
+def gumbel_sample(t: Tensor, temperature: float = 1.0, dim: int = -1) -> Tensor:
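+    """Gumbel-max sampling over logits t; a temperature of 0 reduces to argmax."""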
+    if temperature == 0:
+        return t.argmax(dim=dim)
+    return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim)
+
+
+def orthogonal_loss_fn(t: Tensor) -> Tensor:
+    """Orthogonal loss, eq. (2) from https://arxiv.org/abs/2112.00384."""
+    n = t.shape[0]
+    normed_codes = norm(t)
+    identity = torch.eye(n, device=t.device)
+    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
-- 
cgit v1.2.3-70-g09d2