diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-27 18:17:16 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-27 18:17:16 +0200 |
commit | 68cda1e6abf76312eb13f45ca3c8565a0b00d745 (patch) | |
tree | 4a3a26acc70901675ac92b1f9e2b71685b107864 /text_recognizer/networks/quantizer/quantizer.py | |
parent | aeadb9d82c577879ab8110eb20a9a12d6ca6750c (diff) |
Add quantizer
Diffstat (limited to 'text_recognizer/networks/quantizer/quantizer.py')
-rw-r--r-- | text_recognizer/networks/quantizer/quantizer.py | 69 |
1 files changed, 69 insertions, 0 deletions
# text_recognizer/networks/quantizer/quantizer.py
# (diff --git: new file mode 100644, index 0000000..2c07b79, @@ -0,0 +1,69 @@)
"""Vector quantization layer wrapping a pluggable codebook module."""
from typing import Optional, Tuple, Type

from einops import rearrange
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from text_recognizer.networks.quantizer.utils import orthgonal_loss_fn


class VectorQuantizer(nn.Module):
    """Vector quantizer.

    Flattens a ``(b, d, h, w)`` feature map into a sequence of vectors, maps
    each vector through ``codebook`` and restores the spatial layout. Linear
    projections are inserted around the codebook when its dimensionality
    differs from ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        codebook: nn.Module,
        commitment: float = 1.0,
        ort_reg_weight: float = 0,
        ort_reg_max_codes: Optional[int] = None,
    ) -> None:
        """Initializes the quantizer.

        Args:
            input_dim: Channel dimension ``d`` of the incoming feature map.
            codebook: Codebook module instance. It must expose ``dim`` and
                ``embeddings`` attributes and, when called, return a
                ``(quantized, indices)`` pair.
            commitment: Weight applied to the commitment (MSE) loss.
            ort_reg_weight: Weight of the orthogonal regularization term;
                the term is skipped unless this is greater than zero.
            ort_reg_max_codes: Upper bound on how many codes are passed to
                the orthogonal loss. ``None`` means all codes are used.
        """
        super().__init__()
        self.input_dim = input_dim
        self.codebook = codebook
        self.commitment = commitment
        self.ort_reg_weight = ort_reg_weight
        self.ort_reg_max_codes = ort_reg_max_codes

        # Only project when the codebook works in a different dimensionality.
        require_projection = self.codebook.dim != self.input_dim
        self.project_in = (
            nn.Linear(self.input_dim, self.codebook.dim)
            if require_projection
            else nn.Identity()
        )
        self.project_out = (
            nn.Linear(self.codebook.dim, self.input_dim)
            if require_projection
            else nn.Identity()
        )

    def _orthogonal_loss(self) -> Tensor:
        """Returns the orthogonal loss over (a random subset of) the codes."""
        embeddings = self.codebook.embeddings
        num_codes = embeddings.shape[0]
        max_codes = self.ort_reg_max_codes
        # ``None`` disables subsampling; comparing an int with None would
        # raise a TypeError, so the bound is checked explicitly.
        if max_codes is not None and num_codes > max_codes:
            rand_ids = torch.randperm(num_codes, device=embeddings.device)[
                :max_codes
            ]
            embeddings = embeddings[rand_ids]
        return orthgonal_loss_fn(embeddings)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Quantizes latent vectors.

        Args:
            x: Feature map laid out as ``(b, d, h, w)``.

        Returns:
            A tuple ``(quantized, indices, loss)``. ``quantized`` has the
            layout ``(b, d, h, w)``; ``indices`` is whatever the codebook
            returns as its second output; ``loss`` is the weighted commitment
            loss (plus the optional orthogonal term) in training mode and a
            single-element zero tensor otherwise.
        """
        H, W = x.shape[-2:]
        x = rearrange(x, "b d h w -> b (h w) d")
        x = self.project_in(x)

        quantized, indices = self.codebook(x)

        if self.training:
            loss = F.mse_loss(quantized.detach(), x) * self.commitment
            # Straight-through estimator: the forward value is the quantized
            # vector, while gradients reach ``x`` as if this were identity.
            quantized = x + (quantized - x).detach()
            if self.ort_reg_weight > 0:
                loss = loss + self.ort_reg_weight * self._orthogonal_loss()
        else:
            # No training signal in eval mode; match dtype/device of ``x``.
            loss = x.new_zeros(1)

        quantized = self.project_out(quantized)
        quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)

        return quantized, indices, loss