From 0901bb8172fe56caa3eba9e4bf96ae0b164f9292 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Tue, 13 Sep 2022 18:12:35 +0200
Subject: Remove quantizer

---
 text_recognizer/networks/quantizer/__init__.py     |   2 -
 .../networks/quantizer/cosine_codebook.py          | 104 ---------------------
 text_recognizer/networks/quantizer/kmeans.py       |  32 -------
 text_recognizer/networks/quantizer/quantizer.py    |  69 --------------
 text_recognizer/networks/quantizer/utils.py        |  50 ----------
 5 files changed, 257 deletions(-)
 delete mode 100644 text_recognizer/networks/quantizer/__init__.py
 delete mode 100644 text_recognizer/networks/quantizer/cosine_codebook.py
 delete mode 100644 text_recognizer/networks/quantizer/kmeans.py
 delete mode 100644 text_recognizer/networks/quantizer/quantizer.py
 delete mode 100644 text_recognizer/networks/quantizer/utils.py

diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
deleted file mode 100644
index 9c8685c..0000000
--- a/text_recognizer/networks/quantizer/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from text_recognizer.networks.quantizer.cosine_codebook import CosineSimilarityCodebook
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
diff --git a/text_recognizer/networks/quantizer/cosine_codebook.py b/text_recognizer/networks/quantizer/cosine_codebook.py
deleted file mode 100644
index 3b6af0f..0000000
--- a/text_recognizer/networks/quantizer/cosine_codebook.py
+++ /dev/null
@@ -1,104 +0,0 @@
-"""Codebook module."""
-from typing import Tuple
-
-from einops import rearrange
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.quantizer.kmeans import kmeans
-from text_recognizer.networks.quantizer.utils import (
-    ema_inplace,
-    norm,
-    sample_vectors,
-    gumbel_sample,
-)
-
-
-class CosineSimilarityCodebook(nn.Module):
-    """Cosine similarity codebook."""
-
-    def __init__(
-        self,
-        dim: int,
-        codebook_size: int,
-        kmeans_init: bool = False,
-        kmeans_iters: int = 10,
-        decay: float = 0.8,
-        eps: float = 1.0e-5,
-        threshold_dead: int = 2,
-        temperature: float = 0.0,
-    ) -> None:
-        super().__init__()
-        self.dim = dim
-        self.codebook_size = codebook_size
-        self.kmeans_init = kmeans_init
-        self.kmeans_iters = kmeans_iters
-        self.decay = decay
-        self.eps = eps
-        self.threshold_dead = threshold_dead
-        self.temperature = temperature
-
-        if not self.kmeans_init:
-            embeddings = norm(torch.randn(self.codebook_size, self.dim))
-        else:
-            embeddings = torch.zeros(self.codebook_size, self.dim)
-        self.register_buffer("initalized", Tensor([not self.kmeans_init]))
-        self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
-        self.register_buffer("embeddings", embeddings)
-
-    def _initalize_embedding(self, data: Tensor) -> None:
-        embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
-        self.embeddings.data.copy_(embeddings)
-        self.cluster_size.data.copy_(cluster_size)
-        self.initalized.data.copy_(Tensor([True]))
-
-    def _replace(self, samples: Tensor, mask: Tensor) -> None:
-        samples = norm(samples)
-        modified_codebook = torch.where(
-            mask[..., None],
-            sample_vectors(samples, self.codebook_size),
-            self.embeddings,
-        )
-        self.embeddings.data.copy_(modified_codebook)
-
-    def _replace_dead_codes(self, batch_samples: Tensor) -> None:
-        if self.threshold_dead == 0:
-            return
-        dead_codes = self.cluster_size < self.threshold_dead
-        if not torch.any(dead_codes):
-            return
-        batch_samples = rearrange(batch_samples, "... d -> (...) d")
-        self._replace(batch_samples, mask=dead_codes)
-
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-        """Quantizes tensor."""
-        shape = x.shape
-        flatten = rearrange(x, "... d -> (...) d")
-        flatten = norm(flatten)
-
-        if not self.initalized:
-            self._initalize_embedding(flatten)
-
-        embeddings = norm(self.embeddings)
-        dist = flatten @ embeddings.t()
-        indices = gumbel_sample(dist, dim=-1, temperature=self.temperature)
-        one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
-        indices = indices.view(*shape[:-1])
-
-        quantized = F.embedding(indices, self.embeddings)
-
-        if self.training:
-            bins = one_hot.sum(0)
-            ema_inplace(self.cluster_size, bins, self.decay)
-            zero_mask = bins == 0
-            bins = bins.masked_fill(zero_mask, 1.0)
-
-            embed_sum = flatten.t() @ one_hot
-            embed_norm = (embed_sum / bins.unsqueeze(0)).t()
-            embed_norm = norm(embed_norm)
-            embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
-            ema_inplace(self.embeddings, embed_norm, self.decay)
-            self._replace_dead_codes(x)
-
-        return quantized, indices
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
deleted file mode 100644
index a34c381..0000000
--- a/text_recognizer/networks/quantizer/kmeans.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""K-means clustering for embeddings."""
-from typing import Tuple
-
-from einops import repeat
-import torch
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.utils import norm, sample_vectors
-
-
-def kmeans(
-    samples: Tensor, num_clusters: int, num_iters: int = 10
-) -> Tuple[Tensor, Tensor]:
-    """Compute k-means clusters."""
-    D = samples.shape[-1]
-
-    means = sample_vectors(samples, num_clusters)
-
-    for _ in range(num_iters):
-        dists = samples @ means.t()
-        buckets = dists.max(dim=-1).indices
-        bins = torch.bincount(buckets, minlength=num_clusters)
-        zero_mask = bins == 0
-        bins_min_clamped = bins.masked_fill(zero_mask, 1)
-
-        new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
-        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
-        new_means /= bins_min_clamped[..., None]
-        new_means = norm(new_means)
-        means = torch.where(zero_mask[..., None], means, new_means)
-
-    return means, bins
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
deleted file mode 100644
index 2c07b79..0000000
--- a/text_recognizer/networks/quantizer/quantizer.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from typing import Optional, Tuple, Type
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.quantizer.utils import orthgonal_loss_fn
-
-
-class VectorQuantizer(nn.Module):
-    """Vector quantizer."""
-
-    def __init__(
-        self,
-        input_dim: int,
-        codebook: Type[nn.Module],
-        commitment: float = 1.0,
-        ort_reg_weight: float = 0,
-        ort_reg_max_codes: Optional[int] = None,
-    ) -> None:
-        super().__init__()
-        self.input_dim = input_dim
-        self.codebook = codebook
-        self.commitment = commitment
-        self.ort_reg_weight = ort_reg_weight
-        self.ort_reg_max_codes = ort_reg_max_codes
-        require_projection = self.codebook.dim != self.input_dim
-        self.project_in = (
-            nn.Linear(self.input_dim, self.codebook.dim)
-            if require_projection
-            else nn.Identity()
-        )
-        self.project_out = (
-            nn.Linear(self.codebook.dim, self.input_dim)
-            if require_projection
-            else nn.Identity()
-        )
-
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-        """Quantizes latent vectors."""
-        H, W = x.shape[-2:]
-        device = x.device
-        x = rearrange(x, "b d h w -> b (h w) d")
-        x = self.project_in(x)
-
-        quantized, indices = self.codebook(x)
-
-        if self.training:
-            loss = F.mse_loss(quantized.detach(), x) * self.commitment
-            quantized = x + (quantized - x).detach()
-            if self.ort_reg_weight > 0:
-                codebook = self.codebook.embeddings
-                num_codes = codebook.shape[0]
-                if num_codes > self.ort_reg_max_codes:
-                    rand_ids = torch.randperm(num_codes, device=device)[
-                        : self.ort_reg_max_codes
-                    ]
-                    codebook = codebook[rand_ids]
-                orthgonal_loss = orthgonal_loss_fn(codebook)
-                loss += self.ort_reg_weight * orthgonal_loss
-        else:
-            loss = torch.tensor([0.0]).type_as(x)
-
-        quantized = self.project_out(quantized)
-        quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
-
-        return quantized, indices, loss
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
deleted file mode 100644
index ec97949..0000000
--- a/text_recognizer/networks/quantizer/utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Helper functions for quantization."""
-from typing import Tuple
-
-import torch
-from torch import einsum, Tensor
-import torch.nn.functional as F
-
-
-def sample_vectors(samples: Tensor, num: int) -> Tensor:
-    """Subsamples a set of vectors."""
-    B, device = samples.shape[0], samples.device
-    if B >= num:
-        indices = torch.randperm(B, device=device)[:num]
-    else:
-        indices = torch.randint(0, B, (num,), device=device)[:num]
-    return samples[indices]
-
-
-def norm(t: Tensor) -> Tensor:
-    """Applies L2-normalization."""
-    return F.normalize(t, p=2, dim=-1)
-
-
-def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
-    """Applies exponential moving average."""
-    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
-
-
-def log(t: Tensor, eps: float = 1e-20) -> Tensor:
-    return torch.log(t.clamp(min=eps))
-
-
-def gumbel_noise(t: Tensor) -> Tensor:
-    noise = torch.zeros_like(t).uniform_(0, 1)
-    return -log(-log(noise))
-
-
-def gumbel_sample(t: Tensor, temperature: float = 1.0, dim: int = -1) -> Tensor:
-    if temperature == 0:
-        return t.argmax(dim=dim)
-    return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim)
-
-
-def orthgonal_loss_fn(t: Tensor) -> Tensor:
-    # eq (2) from https://arxiv.org/abs/2112.00384
-    n = t.shape[0]
-    normed_codes = norm(t)
-    identity = torch.eye(n, device=t.device)
-    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
-    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
-- 
cgit v1.2.3-70-g09d2