path: root/text_recognizer/networks/quantizer/quantizer.py
"""Implementation of a Vector Quantized Variational AutoEncoder.

Reference:
https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
"""
from typing import Tuple

import attr
from einops import rearrange
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F


@attr.s(eq=False)
class VectorQuantizer(nn.Module):
    """Vector quantizer."""

    input_dim: int = attr.ib()
    codebook: nn.Module = attr.ib()
    commitment: float = attr.ib(default=1.0)
    project_in: nn.Module = attr.ib(default=None, init=False)
    project_out: nn.Module = attr.ib(default=None, init=False)

    def __attrs_pre_init__(self) -> None:
        super().__init__()

    def __attrs_post_init__(self) -> None:
        require_projection = self.codebook.dim != self.input_dim
        self.project_in = (
            nn.Linear(self.input_dim, self.codebook.dim)
            if require_projection
            else nn.Identity()
        )
        self.project_out = (
            nn.Linear(self.codebook.dim, self.input_dim)
            if require_projection
            else nn.Identity()
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Quantizes latent vectors."""
        # Flatten the spatial grid into a sequence of latent vectors and
        # project into the codebook dimension if it differs from the input.
        H, W = x.shape[-2:]
        x = rearrange(x, "b d h w -> b (h w) d")
        x = self.project_in(x)

        quantized, indices = self.codebook(x)

        if self.training:
            # Commitment loss pulls the encoder output towards its assigned
            # codebook vector; the codebook receives no gradient from this term.
            commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment
            # Straight-through estimator: pass decoder gradients through the
            # quantization step unchanged.
            quantized = x + (quantized - x).detach()
        else:
            commitment_loss = torch.tensor([0.0]).type_as(x)

        quantized = self.project_out(quantized)
        quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)

        return quantized, indices, commitment_loss
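

# --- Usage sketch (assumption, not part of the original file) ---------------
# A minimal illustration of how this quantizer might be wired up. The
# `_ToyCodebook` below is hypothetical: it only mimics the interface the
# quantizer relies on, namely a `dim` attribute and a forward pass returning
# (quantized, indices). The real codebook implementation lives elsewhere in
# this repository.
if __name__ == "__main__":

    class _ToyCodebook(nn.Module):
        """Hypothetical nearest-neighbour codebook, for illustration only."""

        def __init__(self, num_codes: int, dim: int) -> None:
            super().__init__()
            self.dim = dim
            self.embedding = nn.Embedding(num_codes, dim)

        def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
            # Squared euclidean distance between every latent vector and code.
            distances = (
                x.pow(2).sum(-1, keepdim=True)
                - 2 * x @ self.embedding.weight.t()
                + self.embedding.weight.pow(2).sum(-1)
            )
            indices = distances.argmin(dim=-1)
            return self.embedding(indices), indices

    quantizer = VectorQuantizer(
        input_dim=64, codebook=_ToyCodebook(num_codes=512, dim=32)
    )
    latents = torch.randn(2, 64, 8, 8)  # (batch, channels, height, width)
    quantized, indices, commitment_loss = quantizer(latents)
    print(quantized.shape, indices.shape, commitment_loss)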