diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
commit | b44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch) | |
tree | 998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer/networks/vq_transformer.py | |
parent | 3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff) |
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 84 |
1 files changed, 0 insertions, 84 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py deleted file mode 100644 index a2bd81b..0000000 --- a/text_recognizer/networks/vq_transformer.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Vector quantized encoder, transformer decoder.""" -from typing import Optional, Tuple, Type - -from torch import nn, Tensor - -from text_recognizer.networks.conv_transformer import ConvTransformer -from text_recognizer.networks.quantizer.quantizer import VectorQuantizer -from text_recognizer.networks.transformer.layers import Decoder - - -class VqTransformer(ConvTransformer): - """Convolutional encoder and transformer decoder network.""" - - def __init__( - self, - input_dims: Tuple[int, int, int], - hidden_dim: int, - num_classes: int, - pad_index: Tensor, - encoder: nn.Module, - decoder: Decoder, - pixel_pos_embedding: Type[nn.Module], - quantizer: VectorQuantizer, - token_pos_embedding: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__( - input_dims=input_dims, - hidden_dim=hidden_dim, - num_classes=num_classes, - pad_index=pad_index, - encoder=encoder, - decoder=decoder, - pixel_pos_embedding=pixel_pos_embedding, - token_pos_embedding=token_pos_embedding, - ) - self.quantizer = quantizer - - def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes an image into a discrete (VQ) latent representation. - - Args: - x (Tensor): Image tensor. - - Shape: - - x: :math: `(B, C, H, W)` - - z: :math: `(B, Sx, E)` - - where Sx is the length of the flattened feature maps projected from - the encoder. E latent dimension for each pixel in the projected - feature maps. - - Returns: - Tensor: A Latent embedding of the image. - """ - z = self.encoder(x) - z = self.conv(z) - z, _, commitment_loss = self.quantizer(z) - z = self.pixel_pos_embedding(z) - z = z.flatten(start_dim=2) - - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] - z = z.permute(0, 2, 1) - return z, commitment_loss - - def forward(self, x: Tensor, context: Tensor) -> Tensor: - """Encodes images into word piece logtis. - - Args: - x (Tensor): Input image(s). - context (Tensor): Target word embeddings. - - Shapes: - - x: :math: `(B, C, H, W)` - - context: :math: `(B, Sy, T)` - - where B is the batch size, C is the number of input channels, H is - the image height and W is the image width. - - Returns: - Tensor: Sequence of logits. - """ - z, commitment_loss = self.encode(x) - logits = self.decode(z, context) - return logits, commitment_loss |