diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-18 17:43:44 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-18 17:43:44 +0200 |
commit | 1a193f09bd9199609c30a005dbd28f587dce2606 (patch) | |
tree | d548fcc9c22b00e23512473b1cf8ceddc5121716 /text_recognizer/networks/vq_transformer.py | |
parent | ab47938bfa2d70e6244e8431707586e32f2c7d50 (diff) |
Add Vq transformer model
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 67 |
1 files changed, 55 insertions, 12 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index 0433863..69f68fd 100644 --- a/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py @@ -1,5 +1,6 @@ """Vector quantized encoder, transformer decoder.""" -from typing import Tuple +from pathlib import Path +from typing import Tuple, Optional import torch from torch import Tensor @@ -22,7 +23,8 @@ class VqTransformer(ConvTransformer): pad_index: Tensor, encoder: VQVAE, decoder: Decoder, - pretrained_encoder_path: str, + no_grad: bool, + pretrained_encoder_path: Optional[str] = None, ) -> None: super().__init__( input_dims=input_dims, @@ -34,18 +36,36 @@ class VqTransformer(ConvTransformer): encoder=encoder, decoder=decoder, ) - self.pretrained_encoder_path = pretrained_encoder_path - # For typing self.encoder: VQVAE - def setup_encoder(self) -> None: + self.no_grad = no_grad + + if pretrained_encoder_path is not None: + self.pretrained_encoder_path = ( + Path(__file__).resolve().parents[2] / pretrained_encoder_path + ) + self._setup_encoder() + else: + self.pretrained_encoder_path = None + + def _load_pretrained_encoder(self) -> None: + self.encoder.load_state_dict( + torch.load(self.pretrained_encoder_path)["state_dict"]["network"] + ) + + def _setup_encoder(self) -> None: """Remove unecessary layers.""" - # TODO: load pretrained vqvae + self._load_pretrained_encoder() del self.encoder.decoder - del self.encoder.post_codebook_conv + # del self.encoder.post_codebook_conv + + def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + z_e = self.encoder.encode(x) + z_q, commitment_loss = self.encoder.quantize(z_e) + return z_q, commitment_loss - def encode(self, x: Tensor) -> Tensor: + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes an image into a discrete (VQ) latent representation. Args: @@ -62,12 +82,35 @@ class VqTransformer(ConvTransformer): Returns: Tensor: A Latent embedding of the image. 
""" - with torch.no_grad(): - z_e = self.encoder.encode(x) - z_q, _ = self.encoder.quantize(z_e) + if self.no_grad: + with torch.no_grad(): + z_q, commitment_loss = self._encode(x) + else: + z_q, commitment_loss = self._encode(x) z = self.latent_encoder(z_q) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) - return z + return z, commitment_loss + + def forward(self, x: Tensor, context: Tensor) -> Tensor: + """Encodes images into word piece logtis. + + Args: + x (Tensor): Input image(s). + context (Tensor): Target word embeddings. + + Shapes: + - x: :math: `(B, C, H, W)` + - context: :math: `(B, Sy, T)` + + where B is the batch size, C is the number of input channels, H is + the image height and W is the image width. + + Returns: + Tensor: Sequence of logits. + """ + z, commitment_loss = self.encode(x) + logits = self.decode(z, context) + return logits, commitment_loss |