From 2417288c9fe96264da708ce8d13ac7bc2faf83e3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 17 Nov 2021 22:42:58 +0100 Subject: Add new quantizer --- text_recognizer/networks/quantizer/__init__.py | 0 text_recognizer/networks/quantizer/codebook.py | 96 +++++++++++++++++++++++++ text_recognizer/networks/quantizer/kmeans.py | 32 +++++++++ text_recognizer/networks/quantizer/quantizer.py | 59 +++++++++++++++ text_recognizer/networks/quantizer/utils.py | 26 +++++++ 5 files changed, 213 insertions(+) create mode 100644 text_recognizer/networks/quantizer/__init__.py create mode 100644 text_recognizer/networks/quantizer/codebook.py create mode 100644 text_recognizer/networks/quantizer/kmeans.py create mode 100644 text_recognizer/networks/quantizer/quantizer.py create mode 100644 text_recognizer/networks/quantizer/utils.py (limited to 'text_recognizer/networks/quantizer') diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py new file mode 100644 index 0000000..cb9bc59 --- /dev/null +++ b/text_recognizer/networks/quantizer/codebook.py @@ -0,0 +1,96 @@ +"""Codebook module.""" +from typing import Tuple + +import attr +from einops import rearrange +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.networks.quantizer.kmeans import kmeans +from text_recognizer.networks.quantizer.utils import ( + ema_inplace, + norm, + sample_vectors, +) + + +@attr.s(eq=False) +class CosineSimilarityCodebook(nn.Module): + """Cosine similarity codebook.""" + + dim: int = attr.ib() + codebook_size: int = attr.ib() + kmeans_init: bool = attr.ib(default=False) + kmeans_iters: int = attr.ib(default=10) + decay: float = attr.ib(default=0.8) + eps: float = attr.ib(default=1.0e-5) + threshold_dead: int = attr.ib(default=2) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def __attrs_post_init__(self) -> None: + if not self.kmeans_init: + embeddings = norm(torch.randn(self.codebook_size, self.dim)) + else: + embeddings = torch.zeros(self.codebook_size, self.dim) + self.register_buffer("initalized", Tensor([not self.kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) + self.register_buffer("embeddings", embeddings) + + def _initalize_embedding(self, data: Tensor) -> None: + embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embeddings.data.copy_(embeddings) + self.cluster_size.data.copy_(cluster_size) + self.initalized.data.copy_(Tensor([True])) + + def _replace(self, samples: Tensor, mask: Tensor) -> None: + samples = norm(samples) + modified_codebook = torch.where( + mask[..., None], + sample_vectors(samples, self.codebook_size), + self.embeddings, + ) + self.embeddings.data.copy_(modified_codebook) + + def _replace_dead_codes(self, batch_samples: Tensor) -> None: + if self.threshold_dead == 0: + return + dead_codes = self.cluster_size < self.threshold_dead + if not torch.any(dead_codes): + return + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self._replace(batch_samples, mask=dead_codes) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Quantizes tensor.""" + shape = x.shape + flatten = rearrange(x, "... d -> (...) d") + flatten = norm(flatten) + + if not self.initalized: + self._initalize_embedding(flatten) + + embeddings = norm(self.embeddings) + dist = flatten @ embeddings.t() + indices = dist.max(dim=-1).indices + one_hot = F.one_hot(indices, self.codebook_size).type_as(x) + indices = indices.view(*shape[:-1]) + + quantized = F.embedding(indices, self.embeddings) + + if self.training: + bins = one_hot.sum(0) + ema_inplace(self.cluster_size, bins, self.decay) + zero_mask = bins == 0 + bins = bins.masked_fill(zero_mask, 1.0) + + embed_sum = flatten.t() @ one_hot + embed_norm = (embed_sum / bins.unsqueeze(0)).t() + embed_norm = norm(embed_norm) + embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm) + ema_inplace(self.embeddings, embed_norm, self.decay) + self._replace_dead_codes(x) + + return quantized, indices diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py new file mode 100644 index 0000000..a34c381 --- /dev/null +++ b/text_recognizer/networks/quantizer/kmeans.py @@ -0,0 +1,32 @@ +"""K-means clustering for embeddings.""" +from typing import Tuple + +from einops import repeat +import torch +from torch import Tensor + +from text_recognizer.networks.quantizer.utils import norm, sample_vectors + + +def kmeans( + samples: Tensor, num_clusters: int, num_iters: int = 10 +) -> Tuple[Tensor, Tensor]: + """Compute k-means clusters.""" + D = samples.shape[-1] + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + dists = samples @ means.t() + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, D).type_as(samples) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples) + new_means /= bins_min_clamped[..., None] + new_means = norm(new_means) + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py new file mode 100644 index 0000000..3e8f0b2 --- /dev/null +++ b/text_recognizer/networks/quantizer/quantizer.py @@ -0,0 +1,59 @@ +"""Implementation of a Vector Quantized Variational AutoEncoder. + +Reference: +https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py +""" +from typing import Tuple, Type + +import attr +from einops import rearrange +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +@attr.s(eq=False) +class VectorQuantizer(nn.Module): + """Vector quantizer.""" + + input_dim: int = attr.ib() + codebook: Type[nn.Module] = attr.ib() + commitment: float = attr.ib(default=1.0) + project_in: nn.Linear = attr.ib(default=None, init=False) + project_out: nn.Linear = attr.ib(default=None, init=False) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def __attrs_post_init__(self) -> None: + require_projection = self.codebook.dim != self.input_dim + self.project_in = ( + nn.Linear(self.input_dim, self.codebook.dim) + if require_projection + else nn.Identity() + ) + self.project_out = ( + nn.Linear(self.codebook.dim, self.input_dim) + if require_projection + else nn.Identity() + ) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Quantizes latent vectors.""" + H, W = x.shape[-2:] + x = rearrange(x, "b d h w -> b (h w) d") + x = self.project_in(x) + + quantized, indices = self.codebook(x) + + if self.training: + commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment + quantized = x + (quantized - x).detach() + else: + commitment_loss = torch.tensor([0.0]).type_as(x) + + quantized = self.project_out(quantized) + quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W) + + return quantized, indices, commitment_loss diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py new file mode 100644 index 0000000..0502d49 --- /dev/null +++ b/text_recognizer/networks/quantizer/utils.py @@ -0,0 +1,26 @@ +"""Helper functions for quantization.""" +from typing import Tuple + +import torch +from torch import Tensor +import torch.nn.functional as F + + +def sample_vectors(samples: Tensor, num: int) -> Tensor: + """Subsamples a set of vectors.""" + B, device = samples.shape[0], samples.device + if B >= num: + indices = torch.randperm(B, device=device)[:num] + else: + indices = torch.randint(0, B, (num,), device=device)[:num] + return samples[indices] + + +def norm(t: Tensor) -> Tensor: + """Applies L2-normalization.""" + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: + """Applies exponential moving average.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) -- cgit v1.2.3-70-g09d2