diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
commit | 0901bb8172fe56caa3eba9e4bf96ae0b164f9292 (patch) | |
tree | ad1b5964af91a5982fed59715f058586cd28f60d /text_recognizer/networks/quantizer/quantizer.py | |
parent | 7be90f5f101d7ace7ff07180950dac4c11086ec1 (diff) |
Remove quantizer
Diffstat (limited to 'text_recognizer/networks/quantizer/quantizer.py')
-rw-r--r-- | text_recognizer/networks/quantizer/quantizer.py | 69 |
1 files changed, 0 insertions, 69 deletions
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py deleted file mode 100644 index 2c07b79..0000000 --- a/text_recognizer/networks/quantizer/quantizer.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Optional, Tuple, Type - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.quantizer.utils import orthgonal_loss_fn - - -class VectorQuantizer(nn.Module): - """Vector quantizer.""" - - def __init__( - self, - input_dim: int, - codebook: Type[nn.Module], - commitment: float = 1.0, - ort_reg_weight: float = 0, - ort_reg_max_codes: Optional[int] = None, - ) -> None: - super().__init__() - self.input_dim = input_dim - self.codebook = codebook - self.commitment = commitment - self.ort_reg_weight = ort_reg_weight - self.ort_reg_max_codes = ort_reg_max_codes - require_projection = self.codebook.dim != self.input_dim - self.project_in = ( - nn.Linear(self.input_dim, self.codebook.dim) - if require_projection - else nn.Identity() - ) - self.project_out = ( - nn.Linear(self.codebook.dim, self.input_dim) - if require_projection - else nn.Identity() - ) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """Quantizes latent vectors.""" - H, W = x.shape[-2:] - device = x.device - x = rearrange(x, "b d h w -> b (h w) d") - x = self.project_in(x) - - quantized, indices = self.codebook(x) - - if self.training: - loss = F.mse_loss(quantized.detach(), x) * self.commitment - quantized = x + (quantized - x).detach() - if self.ort_reg_weight > 0: - codebook = self.codebook.embeddings - num_codes = codebook.shape[0] - if num_codes > self.ort_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[ - : self.ort_reg_max_codes - ] - codebook = codebook[rand_ids] - orthgonal_loss = orthgonal_loss_fn(codebook) - loss += self.ort_reg_weight * orthgonal_loss - else: - loss = torch.tensor([0.0]).type_as(x) - - quantized = self.project_out(quantized) - quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W) - - return quantized, indices, loss |