Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 77
1 file changed, 77 insertions, 0 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..a972565
--- /dev/null
+++ b/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,77 @@
+"""Vector quantized encoder, transformer decoder."""
+import math
+from typing import Tuple
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.encoders.efficientnet import EfficientNet
+from text_recognizer.networks.conv_transformer import ConvTransformer
+from text_recognizer.networks.transformer.layers import Decoder
+from text_recognizer.networks.transformer.positional_encodings import (
+ PositionalEncoding,
+ PositionalEncoding2D,
+)
+
+
+class VqTransformer(ConvTransformer):
+ """Convolutional encoder and transformer decoder network."""
+
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ dropout_rate: float,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: EfficientNet,
+ decoder: Decoder,
+ ) -> None:
+ # TODO: Load pretrained vqvae encoder.
+ super().__init__(
+ input_dims=input_dims,
+ hidden_dim=hidden_dim,
+ dropout_rate=dropout_rate,
+ num_classes=num_classes,
+ pad_index=pad_index,
+ encoder=encoder,
+ decoder=decoder,
+ )
+        # Latent projector that downsamples the number of filters and adds a
+        # 2D positional encoding.
+ self.latent_encoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.encoder.out_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=1,
+ ),
+ PositionalEncoding2D(
+ hidden_dim=self.hidden_dim,
+ max_h=self.input_dims[1],
+ max_w=self.input_dims[2],
+ ),
+ nn.Flatten(start_dim=2),
+ )
+
+ def encode(self, x: Tensor) -> Tensor:
+ """Encodes an image into a latent feature vector.
+
+ Args:
+ x (Tensor): Image tensor.
+
+ Shape:
+        - x: :math:`(B, C, H, W)`
+        - z: :math:`(B, Sx, E)`
+
+        where Sx is the length of the flattened feature maps projected from
+        the encoder, and E is the latent dimension of each pixel in the
+        projected feature maps.
+
+ Returns:
+            Tensor: A latent embedding of the image.
+ """
+ z = self.encoder(x)
+ z = self.latent_encoder(z)
+
+        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E], where Sx = Ho * Wo.
+ z = z.permute(0, 2, 1)
+ return z