diff options
Diffstat (limited to 'text_recognizer/networks/quantizer')
| -rw-r--r-- | text_recognizer/networks/quantizer/__init__.py | 0 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/codebook.py | 96 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/kmeans.py | 32 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/quantizer.py | 59 | ||||
| -rw-r--r-- | text_recognizer/networks/quantizer/utils.py | 26 | 
5 files changed, 0 insertions, 213 deletions
| diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/text_recognizer/networks/quantizer/__init__.py +++ /dev/null diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py deleted file mode 100644 index cb9bc59..0000000 --- a/text_recognizer/networks/quantizer/codebook.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Codebook module.""" -from typing import Tuple - -import attr -from einops import rearrange -import torch -from torch import nn, Tensor -import torch.nn.functional as F - -from text_recognizer.networks.quantizer.kmeans import kmeans -from text_recognizer.networks.quantizer.utils import ( -    ema_inplace, -    norm, -    sample_vectors, -) - - -@attr.s(eq=False) -class CosineSimilarityCodebook(nn.Module): -    """Cosine similarity codebook.""" - -    dim: int = attr.ib() -    codebook_size: int = attr.ib() -    kmeans_init: bool = attr.ib(default=False) -    kmeans_iters: int = attr.ib(default=10) -    decay: float = attr.ib(default=0.8) -    eps: float = attr.ib(default=1.0e-5) -    threshold_dead: int = attr.ib(default=2) - -    def __attrs_pre_init__(self) -> None: -        super().__init__() - -    def __attrs_post_init__(self) -> None: -        if not self.kmeans_init: -            embeddings = norm(torch.randn(self.codebook_size, self.dim)) -        else: -            embeddings = torch.zeros(self.codebook_size, self.dim) -        self.register_buffer("initalized", Tensor([not self.kmeans_init])) -        self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) -        self.register_buffer("embeddings", embeddings) - -    def _initalize_embedding(self, data: Tensor) -> None: -        embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) -        self.embeddings.data.copy_(embeddings) -        self.cluster_size.data.copy_(cluster_size) -        self.initalized.data.copy_(Tensor([True])) - -    def _replace(self, samples: Tensor, mask: Tensor) -> None: -        samples = norm(samples) -        modified_codebook = torch.where( -            mask[..., None], -            sample_vectors(samples, self.codebook_size), -            self.embeddings, -        ) -        self.embeddings.data.copy_(modified_codebook) - -    def _replace_dead_codes(self, batch_samples: Tensor) -> None: -        if self.threshold_dead == 0: -            return -        dead_codes = self.cluster_size < self.threshold_dead -        if not torch.any(dead_codes): -            return -        batch_samples = rearrange(batch_samples, "... d -> (...) d") -        self._replace(batch_samples, mask=dead_codes) - -    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: -        """Quantizes tensor.""" -        shape = x.shape -        flatten = rearrange(x, "... d -> (...) d") -        flatten = norm(flatten) - -        if not self.initalized: -            self._initalize_embedding(flatten) - -        embeddings = norm(self.embeddings) -        dist = flatten @ embeddings.t() -        indices = dist.max(dim=-1).indices -        one_hot = F.one_hot(indices, self.codebook_size).type_as(x) -        indices = indices.view(*shape[:-1]) - -        quantized = F.embedding(indices, self.embeddings) - -        if self.training: -            bins = one_hot.sum(0) -            ema_inplace(self.cluster_size, bins, self.decay) -            zero_mask = bins == 0 -            bins = bins.masked_fill(zero_mask, 1.0) - -            embed_sum = flatten.t() @ one_hot -            embed_norm = (embed_sum / bins.unsqueeze(0)).t() -            embed_norm = norm(embed_norm) -            embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm) -            ema_inplace(self.embeddings, embed_norm, self.decay) -            self._replace_dead_codes(x) - -        return quantized, indices diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py deleted file mode 100644 index a34c381..0000000 --- a/text_recognizer/networks/quantizer/kmeans.py +++ /dev/null @@ -1,32 +0,0 @@ -"""K-means clustering for embeddings.""" -from typing import Tuple - -from einops import repeat -import torch -from torch import Tensor - -from text_recognizer.networks.quantizer.utils import norm, sample_vectors - - -def kmeans( -    samples: Tensor, num_clusters: int, num_iters: int = 10 -) -> Tuple[Tensor, Tensor]: -    """Compute k-means clusters.""" -    D = samples.shape[-1] - -    means = sample_vectors(samples, num_clusters) - -    for _ in range(num_iters): -        dists = samples @ means.t() -        buckets = dists.max(dim=-1).indices -        bins = torch.bincount(buckets, minlength=num_clusters) -        zero_mask = bins == 0 -        bins_min_clamped = bins.masked_fill(zero_mask, 1) - -        new_means = buckets.new_zeros(num_clusters, D).type_as(samples) -        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples) -        new_means /= bins_min_clamped[..., None] -        new_means = norm(new_means) -        means = torch.where(zero_mask[..., None], means, new_means) - -    return means, bins diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py deleted file mode 100644 index 3e8f0b2..0000000 --- a/text_recognizer/networks/quantizer/quantizer.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Implementation of a Vector Quantized Variational AutoEncoder. - -Reference: -https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py -""" -from typing import Tuple, Type - -import attr -from einops import rearrange -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -@attr.s(eq=False) -class VectorQuantizer(nn.Module): -    """Vector quantizer.""" - -    input_dim: int = attr.ib() -    codebook: Type[nn.Module] = attr.ib() -    commitment: float = attr.ib(default=1.0) -    project_in: nn.Linear = attr.ib(default=None, init=False) -    project_out: nn.Linear = attr.ib(default=None, init=False) - -    def __attrs_pre_init__(self) -> None: -        super().__init__() - -    def __attrs_post_init__(self) -> None: -        require_projection = self.codebook.dim != self.input_dim -        self.project_in = ( -            nn.Linear(self.input_dim, self.codebook.dim) -            if require_projection -            else nn.Identity() -        ) -        self.project_out = ( -            nn.Linear(self.codebook.dim, self.input_dim) -            if require_projection -            else nn.Identity() -        ) - -    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: -        """Quantizes latent vectors.""" -        H, W = x.shape[-2:] -        x = rearrange(x, "b d h w -> b (h w) d") -        x = self.project_in(x) - -        quantized, indices = self.codebook(x) - -        if self.training: -            commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment -            quantized = x + (quantized - x).detach() -        else: -            commitment_loss = torch.tensor([0.0]).type_as(x) - -        quantized = self.project_out(quantized) -        quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W) - -        return quantized, indices, commitment_loss diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py deleted file mode 100644 index 0502d49..0000000 --- a/text_recognizer/networks/quantizer/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Helper functions for quantization.""" -from typing import Tuple - -import torch -from torch import Tensor -import torch.nn.functional as F - - -def sample_vectors(samples: Tensor, num: int) -> Tensor: -    """Subsamples a set of vectors.""" -    B, device = samples.shape[0], samples.device -    if B >= num: -        indices = torch.randperm(B, device=device)[:num] -    else: -        indices = torch.randint(0, B, (num,), device=device)[:num] -    return samples[indices] - - -def norm(t: Tensor) -> Tensor: -    """Applies L2-normalization.""" -    return F.normalize(t, p=2, dim=-1) - - -def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: -    """Applies exponential moving average.""" -    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) |