diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
commit | 0901bb8172fe56caa3eba9e4bf96ae0b164f9292 (patch) | |
tree | ad1b5964af91a5982fed59715f058586cd28f60d /text_recognizer | |
parent | 7be90f5f101d7ace7ff07180950dac4c11086ec1 (diff) |
Remove quantizer
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/quantizer/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/quantizer/cosine_codebook.py | 104 | ||||
-rw-r--r-- | text_recognizer/networks/quantizer/kmeans.py | 32 | ||||
-rw-r--r-- | text_recognizer/networks/quantizer/quantizer.py | 69 | ||||
-rw-r--r-- | text_recognizer/networks/quantizer/utils.py | 50 |
5 files changed, 0 insertions, 257 deletions
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py deleted file mode 100644 index 9c8685c..0000000 --- a/text_recognizer/networks/quantizer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from text_recognizer.networks.quantizer.cosine_codebook import CosineSimilarityCodebook -from text_recognizer.networks.quantizer.quantizer import VectorQuantizer diff --git a/text_recognizer/networks/quantizer/cosine_codebook.py b/text_recognizer/networks/quantizer/cosine_codebook.py deleted file mode 100644 index 3b6af0f..0000000 --- a/text_recognizer/networks/quantizer/cosine_codebook.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Codebook module.""" -from typing import Tuple - -from einops import rearrange -import torch -from torch import nn, Tensor -import torch.nn.functional as F - -from text_recognizer.networks.quantizer.kmeans import kmeans -from text_recognizer.networks.quantizer.utils import ( - ema_inplace, - norm, - sample_vectors, - gumbel_sample, -) - - -class CosineSimilarityCodebook(nn.Module): - """Cosine similarity codebook.""" - - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: bool = False, - kmeans_iters: int = 10, - decay: float = 0.8, - eps: float = 1.0e-5, - threshold_dead: int = 2, - temperature: float = 0.0, - ) -> None: - super().__init__() - self.dim = dim - self.codebook_size = codebook_size - self.kmeans_init = kmeans_init - self.kmeans_iters = kmeans_iters - self.decay = decay - self.eps = eps - self.threshold_dead = threshold_dead - self.temperature = temperature - - if not self.kmeans_init: - embeddings = norm(torch.randn(self.codebook_size, self.dim)) - else: - embeddings = torch.zeros(self.codebook_size, self.dim) - self.register_buffer("initalized", Tensor([not self.kmeans_init])) - self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) - self.register_buffer("embeddings", embeddings) - - def _initalize_embedding(self, data: Tensor) -> None: - embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embeddings.data.copy_(embeddings) - self.cluster_size.data.copy_(cluster_size) - self.initalized.data.copy_(Tensor([True])) - - def _replace(self, samples: Tensor, mask: Tensor) -> None: - samples = norm(samples) - modified_codebook = torch.where( - mask[..., None], - sample_vectors(samples, self.codebook_size), - self.embeddings, - ) - self.embeddings.data.copy_(modified_codebook) - - def _replace_dead_codes(self, batch_samples: Tensor) -> None: - if self.threshold_dead == 0: - return - dead_codes = self.cluster_size < self.threshold_dead - if not torch.any(dead_codes): - return - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self._replace(batch_samples, mask=dead_codes) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Quantizes tensor.""" - shape = x.shape - flatten = rearrange(x, "... d -> (...) d") - flatten = norm(flatten) - - if not self.initalized: - self._initalize_embedding(flatten) - - embeddings = norm(self.embeddings) - dist = flatten @ embeddings.t() - indices = gumbel_sample(dist, dim=-1, temperature=self.temperature) - one_hot = F.one_hot(indices, self.codebook_size).type_as(x) - indices = indices.view(*shape[:-1]) - - quantized = F.embedding(indices, self.embeddings) - - if self.training: - bins = one_hot.sum(0) - ema_inplace(self.cluster_size, bins, self.decay) - zero_mask = bins == 0 - bins = bins.masked_fill(zero_mask, 1.0) - - embed_sum = flatten.t() @ one_hot - embed_norm = (embed_sum / bins.unsqueeze(0)).t() - embed_norm = norm(embed_norm) - embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm) - ema_inplace(self.embeddings, embed_norm, self.decay) - self._replace_dead_codes(x) - - return quantized, indices diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py deleted file mode 100644 index a34c381..0000000 --- a/text_recognizer/networks/quantizer/kmeans.py +++ /dev/null @@ -1,32 +0,0 @@ -"""K-means clustering for embeddings.""" -from typing import Tuple - -from einops import repeat -import torch -from torch import Tensor - -from text_recognizer.networks.quantizer.utils import norm, sample_vectors - - -def kmeans( - samples: Tensor, num_clusters: int, num_iters: int = 10 -) -> Tuple[Tensor, Tensor]: - """Compute k-means clusters.""" - D = samples.shape[-1] - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - dists = samples @ means.t() - buckets = dists.max(dim=-1).indices - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, D).type_as(samples) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples) - new_means /= bins_min_clamped[..., None] - new_means = norm(new_means) - means = torch.where(zero_mask[..., None], means, new_means) - - return means, bins diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py deleted file mode 100644 index 2c07b79..0000000 --- a/text_recognizer/networks/quantizer/quantizer.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Optional, Tuple, Type - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.quantizer.utils import orthgonal_loss_fn - - -class VectorQuantizer(nn.Module): - """Vector quantizer.""" - - def __init__( - self, - input_dim: int, - codebook: Type[nn.Module], - commitment: float = 1.0, - ort_reg_weight: float = 0, - ort_reg_max_codes: Optional[int] = None, - ) -> None: - super().__init__() - self.input_dim = input_dim - self.codebook = codebook - self.commitment = commitment - self.ort_reg_weight = ort_reg_weight - self.ort_reg_max_codes = ort_reg_max_codes - require_projection = self.codebook.dim != self.input_dim - self.project_in = ( - nn.Linear(self.input_dim, self.codebook.dim) - if require_projection - else nn.Identity() - ) - self.project_out = ( - nn.Linear(self.codebook.dim, self.input_dim) - if require_projection - else nn.Identity() - ) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """Quantizes latent vectors.""" - H, W = x.shape[-2:] - device = x.device - x = rearrange(x, "b d h w -> b (h w) d") - x = self.project_in(x) - - quantized, indices = self.codebook(x) - - if self.training: - loss = F.mse_loss(quantized.detach(), x) * self.commitment - quantized = x + (quantized - x).detach() - if self.ort_reg_weight > 0: - codebook = self.codebook.embeddings - num_codes = codebook.shape[0] - if num_codes > self.ort_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[ - : self.ort_reg_max_codes - ] - codebook = codebook[rand_ids] - orthgonal_loss = orthgonal_loss_fn(codebook) - loss += self.ort_reg_weight * orthgonal_loss - else: - loss = torch.tensor([0.0]).type_as(x) - - quantized = self.project_out(quantized) - quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W) - - return quantized, indices, loss diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py deleted file mode 100644 index ec97949..0000000 --- a/text_recognizer/networks/quantizer/utils.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Helper functions for quantization.""" -from typing import Tuple - -import torch -from torch import einsum, Tensor -import torch.nn.functional as F - - -def sample_vectors(samples: Tensor, num: int) -> Tensor: - """Subsamples a set of vectors.""" - B, device = samples.shape[0], samples.device - if B >= num: - indices = torch.randperm(B, device=device)[:num] - else: - indices = torch.randint(0, B, (num,), device=device)[:num] - return samples[indices] - - -def norm(t: Tensor) -> Tensor: - """Applies L2-normalization.""" - return F.normalize(t, p=2, dim=-1) - - -def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: - """Applies exponential moving average.""" - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -def log(t: Tensor, eps: float = 1e-20) -> Tensor: - return torch.log(t.clamp(min=eps)) - - -def gumbel_noise(t: Tensor) -> Tensor: - noise = torch.zeros_like(t).uniform_(0, 1) - return -log(-log(noise)) - - -def gumbel_sample(t: Tensor, temperature: float = 1.0, dim: int = -1) -> Tensor: - if temperature == 0: - return t.argmax(dim=dim) - return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim) - - -def orthgonal_loss_fn(t: Tensor) -> Tensor: - # eq (2) from https://arxiv.org/abs/2112.00384 - n = t.shape[0] - normed_codes = norm(t) - identity = torch.eye(n, device=t.device) - cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) - return ((cosine_sim - identity) ** 2).sum() / (n ** 2) |