path: root/text_recognizer/networks/quantizer/codebook.py

"""Codebook module."""
from typing import Tuple

import attr
from einops import rearrange
import torch
from torch import nn, Tensor
import torch.nn.functional as F

from text_recognizer.networks.quantizer.kmeans import kmeans
from text_recognizer.networks.quantizer.utils import (
    ema_inplace,
    norm,
    sample_vectors,
)


@attr.s(eq=False)
class CosineSimilarityCodebook(nn.Module):
    """Cosine similarity codebook."""

    dim: int = attr.ib()
    codebook_size: int = attr.ib()
    kmeans_init: bool = attr.ib(default=False)
    kmeans_iters: int = attr.ib(default=10)
    decay: float = attr.ib(default=0.8)
    eps: float = attr.ib(default=1.0e-5)
    threshold_dead: int = attr.ib(default=2)

    def __attrs_pre_init__(self) -> None:
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Without k-means the codebook starts from normalized random vectors;
        # otherwise it is filled lazily from the first batch seen in forward.
        if not self.kmeans_init:
            embeddings = norm(torch.randn(self.codebook_size, self.dim))
        else:
            embeddings = torch.zeros(self.codebook_size, self.dim)
        self.register_buffer("initialized", Tensor([not self.kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
        self.register_buffer("embeddings", embeddings)

    def _initialize_embedding(self, data: Tensor) -> None:
        """Initializes the codebook with k-means centroids computed from data."""
        embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embeddings.data.copy_(embeddings)
        self.cluster_size.data.copy_(cluster_size)
        self.initialized.data.copy_(Tensor([True]))

    def _replace(self, samples: Tensor, mask: Tensor) -> None:
        """Replaces masked codebook entries with vectors sampled from samples."""
        samples = norm(samples)
        modified_codebook = torch.where(
            mask[..., None],
            sample_vectors(samples, self.codebook_size),
            self.embeddings,
        )
        self.embeddings.data.copy_(modified_codebook)

    def _replace_dead_codes(self, batch_samples: Tensor) -> None:
        """Resamples codes whose EMA cluster size fell below threshold_dead."""
        if self.threshold_dead == 0:
            return
        dead_codes = self.cluster_size < self.threshold_dead
        if not torch.any(dead_codes):
            return
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self._replace(batch_samples, mask=dead_codes)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Quantizes the input tensor.

        Returns the quantized tensor and the indices of the selected codes.
        """
        shape = x.shape
        flatten = rearrange(x, "... d -> (...) d")
        flatten = norm(flatten)

        # Lazy k-means initialization from the first batch of inputs.
        if not self.initialized:
            self._initialize_embedding(flatten)

        embeddings = norm(self.embeddings)
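        # With unit-norm vectors the dot product equals cosine similarity, so the
        # nearest code for each input is the one with the largest inner product.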
        dist = flatten @ embeddings.t()
        indices = dist.max(dim=-1).indices
        one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
        indices = indices.view(*shape[:-1])

        quantized = F.embedding(indices, self.embeddings)

        if self.training:
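            # EMA update of per-code usage counts; codes unused in this batch are
            # masked so the mean computation below never divides by zero.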
            bins = one_hot.sum(0)
            ema_inplace(self.cluster_size, bins, self.decay)
            zero_mask = bins == 0
            bins = bins.masked_fill(zero_mask, 1.0)

            embed_sum = flatten.t() @ one_hot
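            # EMA-update the codebook toward the normalized per-code means; codes
            # with no assignments in this batch keep their previous embedding.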
            embed_norm = (embed_sum / bins.unsqueeze(0)).t()
            embed_norm = norm(embed_norm)
            embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
            ema_inplace(self.embeddings, embed_norm, self.decay)
            self._replace_dead_codes(x)

        return quantized, indices
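

# Illustrative usage sketch (not part of the original module). The shapes and
# hyperparameters below are assumptions chosen only to demonstrate the API:
# forward takes a (..., dim) tensor and returns the quantized tensor together
# with the selected codebook indices.
if __name__ == "__main__":
    codebook = CosineSimilarityCodebook(dim=256, codebook_size=512)
    x = torch.randn(4, 144, 256)  # (batch, num_tokens, dim) -- assumed layout
    quantized, indices = codebook(x)
    print(quantized.shape)  # torch.Size([4, 144, 256])
    print(indices.shape)  # torch.Size([4, 144])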