diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-17 22:44:10 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-17 22:44:10 +0100 |
commit | a9a7c8ab3314e36eed7781552bf1b383a624eea2 (patch) | |
tree | 89cd4cb6abf8370150a94d9fb1a8ca2faa7345b8 /text_recognizer/networks | |
parent | 8b06e0ff3185436848c08bf04d730d7e5212e0e5 (diff) |
Update VqTransformer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 69 |
1 files changed, 16 insertions, 53 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index 1c85bf9..a2bd81b 100644 --- a/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py @@ -1,14 +1,10 @@ """Vector quantized encoder, transformer decoder.""" -from pathlib import Path -from typing import OrderedDict, Tuple +from typing import Optional, Tuple, Type -from omegaconf import OmegaConf -from hydra.utils import instantiate -import torch -from torch import Tensor +from torch import nn, Tensor -from text_recognizer.networks.vqvae.vqvae import VQVAE from text_recognizer.networks.conv_transformer import ConvTransformer +from text_recognizer.networks.quantizer.quantizer import VectorQuantizer from text_recognizer.networks.transformer.layers import Decoder @@ -18,57 +14,26 @@ class VqTransformer(ConvTransformer): def __init__( self, input_dims: Tuple[int, int, int], - encoder_dim: int, hidden_dim: int, - dropout_rate: float, num_classes: int, pad_index: Tensor, + encoder: nn.Module, decoder: Decoder, - no_grad: bool, - pretrained_encoder_path: str, + pixel_pos_embedding: Type[nn.Module], + quantizer: VectorQuantizer, + token_pos_embedding: Optional[Type[nn.Module]] = None, ) -> None: - # For typing - self.encoder: VQVAE = None - self.no_grad = no_grad - super().__init__( input_dims=input_dims, - encoder_dim=encoder_dim, hidden_dim=hidden_dim, - dropout_rate=dropout_rate, num_classes=num_classes, pad_index=pad_index, - encoder=self.encoder, + encoder=encoder, decoder=decoder, + pixel_pos_embedding=pixel_pos_embedding, + token_pos_embedding=token_pos_embedding, ) - self._setup_encoder(pretrained_encoder_path) - - def _load_state_dict(self, path: Path) -> OrderedDict: - weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0] - renamed_state_dict = OrderedDict() - state_dict = torch.load(weights_path)["state_dict"] - for key in state_dict.keys(): - if "network" in key: - new_key = key.removeprefix("network.") - renamed_state_dict[new_key] = state_dict[key] - del state_dict - return renamed_state_dict - - def _setup_encoder(self, pretrained_encoder_path: str,) -> None: - """Load encoder module.""" - path = Path(__file__).resolve().parents[2] / pretrained_encoder_path - with open(path / "config.yaml") as f: - cfg = OmegaConf.load(f) - state_dict = self._load_state_dict(path) - self.encoder = instantiate(cfg.network) - self.encoder.load_state_dict(state_dict) - del self.encoder.decoder - - def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: - z_e = self.encoder.encode(x) - z_q, commitment_loss = self.encoder.quantize(z_e) - z = self.encoder.post_codebook_conv(z_q) - return z, commitment_loss + self.quantizer = quantizer def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes an image into a discrete (VQ) latent representation. @@ -87,13 +52,11 @@ class VqTransformer(ConvTransformer): Returns: Tensor: A Latent embedding of the image. """ - if self.no_grad: - with torch.no_grad(): - z_q, commitment_loss = self._encode(x) - else: - z_q, commitment_loss = self._encode(x) - - z = self.latent_encoder(z_q) + z = self.encoder(x) + z = self.conv(z) + z, _, commitment_loss = self.quantizer(z) + z = self.pixel_pos_embedding(z) + z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) |