summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vq_transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 02:42:45 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 02:42:45 +0200
commit3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch)
tree136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vq_transformer.py
parent1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff)
Fix VQVAE into en/decoder, bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r--text_recognizer/networks/vq_transformer.py50
1 files changed, 23 insertions, 27 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index a972565..0433863 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -1,16 +1,12 @@
"""Vector quantized encoder, transformer decoder."""
-import math
from typing import Tuple
-from torch import nn, Tensor
+import torch
+from torch import Tensor
-from text_recognizer.networks.encoders.efficientnet import EfficientNet
+from text_recognizer.networks.vqvae.vqvae import VQVAE
from text_recognizer.networks.conv_transformer import ConvTransformer
from text_recognizer.networks.transformer.layers import Decoder
-from text_recognizer.networks.transformer.positional_encodings import (
- PositionalEncoding,
- PositionalEncoding2D,
-)
class VqTransformer(ConvTransformer):
@@ -19,16 +15,18 @@ class VqTransformer(ConvTransformer):
def __init__(
self,
input_dims: Tuple[int, int, int],
+ encoder_dim: int,
hidden_dim: int,
dropout_rate: float,
num_classes: int,
pad_index: Tensor,
- encoder: EfficientNet,
+ encoder: VQVAE,
decoder: Decoder,
+ pretrained_encoder_path: str,
) -> None:
- # TODO: Load pretrained vqvae encoder.
super().__init__(
input_dims=input_dims,
+ encoder_dim=encoder_dim,
hidden_dim=hidden_dim,
dropout_rate=dropout_rate,
num_classes=num_classes,
@@ -36,24 +34,19 @@ class VqTransformer(ConvTransformer):
encoder=encoder,
decoder=decoder,
)
- # Latent projector for down sampling number of filters and 2d
- # positional encoding.
- self.latent_encoder = nn.Sequential(
- nn.Conv2d(
- in_channels=self.encoder.out_channels,
- out_channels=self.hidden_dim,
- kernel_size=1,
- ),
- PositionalEncoding2D(
- hidden_dim=self.hidden_dim,
- max_h=self.input_dims[1],
- max_w=self.input_dims[2],
- ),
- nn.Flatten(start_dim=2),
- )
+ self.pretrained_encoder_path = pretrained_encoder_path
+
+ # For typing
+ self.encoder: VQVAE
+
+ def setup_encoder(self) -> None:
+ """Remove unecessary layers."""
+ # TODO: load pretrained vqvae
+ del self.encoder.decoder
+ del self.encoder.post_codebook_conv
def encode(self, x: Tensor) -> Tensor:
- """Encodes an image into a latent feature vector.
+ """Encodes an image into a discrete (VQ) latent representation.
Args:
x (Tensor): Image tensor.
@@ -69,8 +62,11 @@ class VqTransformer(ConvTransformer):
Returns:
Tensor: A Latent embedding of the image.
"""
- z = self.encoder(x)
- z = self.latent_encoder(z)
+ with torch.no_grad():
+ z_e = self.encoder.encode(x)
+ z_q, _ = self.encoder.quantize(z_e)
+
+ z = self.latent_encoder(z_q)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
z = z.permute(0, 2, 1)