summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/quantizer/cosine_codebook.py
blob: 3b6af0f9c861df7b0bd4d6c212d0da8ba180ddfb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Codebook module."""
from typing import Tuple

from einops import rearrange
import torch
from torch import nn, Tensor
import torch.nn.functional as F

from text_recognizer.networks.quantizer.kmeans import kmeans
from text_recognizer.networks.quantizer.utils import (
    ema_inplace,
    norm,
    sample_vectors,
    gumbel_sample,
)


class CosineSimilarityCodebook(nn.Module):
    """Cosine similarity codebook."""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1.0e-5,
        threshold_dead: int = 2,
        temperature: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.codebook_size = codebook_size
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.decay = decay
        self.eps = eps
        self.threshold_dead = threshold_dead
        self.temperature = temperature

        if not self.kmeans_init:
            embeddings = norm(torch.randn(self.codebook_size, self.dim))
        else:
            embeddings = torch.zeros(self.codebook_size, self.dim)
        self.register_buffer("initalized", Tensor([not self.kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
        self.register_buffer("embeddings", embeddings)

    def _initalize_embedding(self, data: Tensor) -> None:
        embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embeddings.data.copy_(embeddings)
        self.cluster_size.data.copy_(cluster_size)
        self.initalized.data.copy_(Tensor([True]))

    def _replace(self, samples: Tensor, mask: Tensor) -> None:
        samples = norm(samples)
        modified_codebook = torch.where(
            mask[..., None],
            sample_vectors(samples, self.codebook_size),
            self.embeddings,
        )
        self.embeddings.data.copy_(modified_codebook)

    def _replace_dead_codes(self, batch_samples: Tensor) -> None:
        if self.threshold_dead == 0:
            return
        dead_codes = self.cluster_size < self.threshold_dead
        if not torch.any(dead_codes):
            return
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self._replace(batch_samples, mask=dead_codes)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Quantizes tensor."""
        shape = x.shape
        flatten = rearrange(x, "... d -> (...) d")
        flatten = norm(flatten)

        if not self.initalized:
            self._initalize_embedding(flatten)

        embeddings = norm(self.embeddings)
        dist = flatten @ embeddings.t()
        indices = gumbel_sample(dist, dim=-1, temperature=self.temperature)
        one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
        indices = indices.view(*shape[:-1])

        quantized = F.embedding(indices, self.embeddings)

        if self.training:
            bins = one_hot.sum(0)
            ema_inplace(self.cluster_size, bins, self.decay)
            zero_mask = bins == 0
            bins = bins.masked_fill(zero_mask, 1.0)

            embed_sum = flatten.t() @ one_hot
            embed_norm = (embed_sum / bins.unsqueeze(0)).t()
            embed_norm = norm(embed_norm)
            embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
            ema_inplace(self.embeddings, embed_norm, self.decay)
            self._replace_dead_codes(x)

        return quantized, indices