summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/vq_transformer.py57
1 files changed, 0 insertions, 57 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index c12e18b..0000000
--- a/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Optional, Tuple, Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.transformer.decoder import Decoder
-from text_recognizer.networks.transformer.embeddings.axial import (
- AxialPositionalEmbedding,
-)
-
-from text_recognizer.networks.conv_transformer import ConvTransformer
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-
-
-class VqTransformer(ConvTransformer):
- def __init__(
- self,
- input_dims: Tuple[int, int, int],
- hidden_dim: int,
- num_classes: int,
- pad_index: Tensor,
- encoder: Type[nn.Module],
- decoder: Decoder,
- pixel_embedding: AxialPositionalEmbedding,
- token_pos_embedding: Optional[Type[nn.Module]] = None,
- quantizer: Optional[VectorQuantizer] = None,
- ) -> None:
- super().__init__(
- input_dims,
- hidden_dim,
- num_classes,
- pad_index,
- encoder,
- decoder,
- pixel_embedding,
- token_pos_embedding,
- )
- self.quantizer = quantizer
-
- def quantize(self, z: Tensor) -> Tuple[Tensor, Tensor]:
- q, _, loss = self.quantizer(z)
- return q, loss
-
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- z = self.encoder(x)
- z = self.conv(z)
- q, loss = self.quantize(z)
- z = self.pixel_embedding(q)
- z = z.flatten(start_dim=2)
-
- # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
- z = z.permute(0, 2, 1)
- return z, loss
-
- def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]:
- z, loss = self.encode(x)
- logits = self.decode(z, context)
- return logits, loss