diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
commit | d3afa310f77f47553586eeee58e3d3345a754e2c (patch) | |
tree | 08b7de1daf2550852d0a1e4d4d75202f14bb03d4 /text_recognizer/networks/vq_transformer.py | |
parent | 65d5f6c694e73792e40ed693a1381a792da8d277 (diff) |
New VQVAE
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py new file mode 100644 index 0000000..a972565 --- /dev/null +++ b/text_recognizer/networks/vq_transformer.py @@ -0,0 +1,77 @@ +"""Vector quantized encoder, transformer decoder.""" +import math +from typing import Tuple + +from torch import nn, Tensor + +from text_recognizer.networks.encoders.efficientnet import EfficientNet +from text_recognizer.networks.conv_transformer import ConvTransformer +from text_recognizer.networks.transformer.layers import Decoder +from text_recognizer.networks.transformer.positional_encodings import ( + PositionalEncoding, + PositionalEncoding2D, +) + + +class VqTransformer(ConvTransformer): + """Convolutional encoder and transformer decoder network.""" + + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + dropout_rate: float, + num_classes: int, + pad_index: Tensor, + encoder: EfficientNet, + decoder: Decoder, + ) -> None: + # TODO: Load pretrained vqvae encoder. + super().__init__( + input_dims=input_dims, + hidden_dim=hidden_dim, + dropout_rate=dropout_rate, + num_classes=num_classes, + pad_index=pad_index, + encoder=encoder, + decoder=decoder, + ) + # Latent projector for down sampling number of filters and 2d + # positional encoding. + self.latent_encoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, + ), + PositionalEncoding2D( + hidden_dim=self.hidden_dim, + max_h=self.input_dims[1], + max_w=self.input_dims[2], + ), + nn.Flatten(start_dim=2), + ) + + def encode(self, x: Tensor) -> Tensor: + """Encodes an image into a latent feature vector. + + Args: + x (Tensor): Image tensor. + + Shape: + - x: :math: `(B, C, H, W)` + - z: :math: `(B, Sx, E)` + + where Sx is the length of the flattened feature maps projected from + the encoder. E latent dimension for each pixel in the projected + feature maps. + + Returns: + Tensor: A Latent embedding of the image. + """ + z = self.encoder(x) + z = self.latent_encoder(z) + + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z |