Diffstat (limited to 'text_recognizer/networks/quantizer/codebook.py')
-rw-r--r--  text_recognizer/networks/quantizer/codebook.py  108
1 file changed, 108 insertions, 0 deletions
diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py
new file mode 100644
index 0000000..cb9bc59
--- /dev/null
+++ b/text_recognizer/networks/quantizer/codebook.py
@@ -0,0 +1,108 @@
+"""Codebook module."""
+from typing import Tuple
+
+import attr
+from einops import rearrange
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.quantizer.kmeans import kmeans
+from text_recognizer.networks.quantizer.utils import (
+ ema_inplace,
+ norm,
+ sample_vectors,
+)
+
+
+@attr.s(eq=False)
+class CosineSimilarityCodebook(nn.Module):
+ """Cosine similarity codebook."""
+
+ dim: int = attr.ib()
+ codebook_size: int = attr.ib()
+ kmeans_init: bool = attr.ib(default=False)
+ kmeans_iters: int = attr.ib(default=10)
+ decay: float = attr.ib(default=0.8)
+ eps: float = attr.ib(default=1.0e-5)
+ threshold_dead: int = attr.ib(default=2)
+
+ def __attrs_pre_init__(self) -> None:
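+        # nn.Module.__init__ must run before attrs assigns any fields.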
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
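+        # Without k-means init, start from a random L2-normalized codebook;
+        # otherwise the codebook is seeded from the first batch in forward.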
+ if not self.kmeans_init:
+ embeddings = norm(torch.randn(self.codebook_size, self.dim))
+ else:
+ embeddings = torch.zeros(self.codebook_size, self.dim)
+ self.register_buffer("initalized", Tensor([not self.kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
+ self.register_buffer("embeddings", embeddings)
+
+    def _initialize_embedding(self, data: Tensor) -> None:
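+        # Seed the codebook with k-means centroids computed on the batch.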
+ embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embeddings.data.copy_(embeddings)
+ self.cluster_size.data.copy_(cluster_size)
+        self.initialized.data.copy_(Tensor([True]))
+
+ def _replace(self, samples: Tensor, mask: Tensor) -> None:
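+        # Overwrite the masked codes with vectors sampled from the batch.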
+ samples = norm(samples)
+ modified_codebook = torch.where(
+ mask[..., None],
+ sample_vectors(samples, self.codebook_size),
+ self.embeddings,
+ )
+ self.embeddings.data.copy_(modified_codebook)
+
+ def _replace_dead_codes(self, batch_samples: Tensor) -> None:
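+        # Codes whose EMA cluster size fell below the threshold count as dead.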
+ if self.threshold_dead == 0:
+ return
+ dead_codes = self.cluster_size < self.threshold_dead
+ if not torch.any(dead_codes):
+ return
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self._replace(batch_samples, mask=dead_codes)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Quantizes tensor."""
+ shape = x.shape
+ flatten = rearrange(x, "... d -> (...) d")
+ flatten = norm(flatten)
+
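+        # Lazily seed the codebook from the first batch when using k-means init.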
+        if not self.initialized:
+            self._initialize_embedding(flatten)
+
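+        # Cosine similarity is a dot product between unit-normalized vectors.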
+ embeddings = norm(self.embeddings)
+ dist = flatten @ embeddings.t()
+ indices = dist.max(dim=-1).indices
+ one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
+ indices = indices.view(*shape[:-1])
+
+ quantized = F.embedding(indices, self.embeddings)
+
+ if self.training:
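+            # EMA update of per-code usage counts; codes unused in this batch
+            # reuse their current (normalized) embedding as the EMA target below.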
+ bins = one_hot.sum(0)
+ ema_inplace(self.cluster_size, bins, self.decay)
+ zero_mask = bins == 0
+ bins = bins.masked_fill(zero_mask, 1.0)
+
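+            # EMA update of the codebook towards the normalized batch means.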
+ embed_sum = flatten.t() @ one_hot
+ embed_norm = (embed_sum / bins.unsqueeze(0)).t()
+ embed_norm = norm(embed_norm)
+ embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
+ ema_inplace(self.embeddings, embed_norm, self.decay)
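+            # Resample codes whose EMA usage dropped below the dead-code threshold.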
+ self._replace_dead_codes(x)
+
+ return quantized, indices