summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vq_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r--text_recognizer/networks/vq_transformer.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..c12e18b
--- /dev/null
+++ b/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,57 @@
+from typing import Optional, Tuple, Type
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.transformer.decoder import Decoder
+from text_recognizer.networks.transformer.embeddings.axial import (
+ AxialPositionalEmbedding,
+)
+
+from text_recognizer.networks.conv_transformer import ConvTransformer
+from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
+
+
+class VqTransformer(ConvTransformer):
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: Type[nn.Module],
+ decoder: Decoder,
+ pixel_embedding: AxialPositionalEmbedding,
+ token_pos_embedding: Optional[Type[nn.Module]] = None,
+ quantizer: Optional[VectorQuantizer] = None,
+ ) -> None:
+ super().__init__(
+ input_dims,
+ hidden_dim,
+ num_classes,
+ pad_index,
+ encoder,
+ decoder,
+ pixel_embedding,
+ token_pos_embedding,
+ )
+ self.quantizer = quantizer
+
+ def quantize(self, z: Tensor) -> Tuple[Tensor, Tensor]:
+ q, _, loss = self.quantizer(z)
+ return q, loss
+
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ z = self.encoder(x)
+ z = self.conv(z)
+ q, loss = self.quantize(z)
+ z = self.pixel_embedding(q)
+ z = z.flatten(start_dim=2)
+
+ # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
+ z = z.permute(0, 2, 1)
+ return z, loss
+
+ def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]:
+ z, loss = self.encode(x)
+ logits = self.decode(z, context)
+ return logits, loss