diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
commit | 3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch) | |
tree | 136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vq_transformer.py | |
parent | 1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff) |
Fix VQVAE into en/decoder, bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 50 |
1 files changed, 23 insertions, 27 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index a972565..0433863 100644 --- a/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py @@ -1,16 +1,12 @@ """Vector quantized encoder, transformer decoder.""" -import math from typing import Tuple -from torch import nn, Tensor +import torch +from torch import Tensor -from text_recognizer.networks.encoders.efficientnet import EfficientNet +from text_recognizer.networks.vqvae.vqvae import VQVAE from text_recognizer.networks.conv_transformer import ConvTransformer from text_recognizer.networks.transformer.layers import Decoder -from text_recognizer.networks.transformer.positional_encodings import ( - PositionalEncoding, - PositionalEncoding2D, -) class VqTransformer(ConvTransformer): @@ -19,16 +15,18 @@ class VqTransformer(ConvTransformer): def __init__( self, input_dims: Tuple[int, int, int], + encoder_dim: int, hidden_dim: int, dropout_rate: float, num_classes: int, pad_index: Tensor, - encoder: EfficientNet, + encoder: VQVAE, decoder: Decoder, + pretrained_encoder_path: str, ) -> None: - # TODO: Load pretrained vqvae encoder. super().__init__( input_dims=input_dims, + encoder_dim=encoder_dim, hidden_dim=hidden_dim, dropout_rate=dropout_rate, num_classes=num_classes, @@ -36,24 +34,19 @@ class VqTransformer(ConvTransformer): encoder=encoder, decoder=decoder, ) - # Latent projector for down sampling number of filters and 2d - # positional encoding. - self.latent_encoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ), - PositionalEncoding2D( - hidden_dim=self.hidden_dim, - max_h=self.input_dims[1], - max_w=self.input_dims[2], - ), - nn.Flatten(start_dim=2), - ) + self.pretrained_encoder_path = pretrained_encoder_path + + # For typing + self.encoder: VQVAE + + def setup_encoder(self) -> None: + """Remove unecessary layers.""" + # TODO: load pretrained vqvae + del self.encoder.decoder + del self.encoder.post_codebook_conv def encode(self, x: Tensor) -> Tensor: - """Encodes an image into a latent feature vector. + """Encodes an image into a discrete (VQ) latent representation. Args: x (Tensor): Image tensor. @@ -69,8 +62,11 @@ class VqTransformer(ConvTransformer): Returns: Tensor: A Latent embedding of the image. """ - z = self.encoder(x) - z = self.latent_encoder(z) + with torch.no_grad(): + z_e = self.encoder.encode(x) + z_q, _ = self.encoder.quantize(z_e) + + z = self.latent_encoder(z_q) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) |