Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r--  text_recognizer/networks/vq_transformer.py  84
1 file changed, 0 insertions(+), 84 deletions(-)
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index a2bd81b..0000000
--- a/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""Vector quantized encoder, transformer decoder."""
-from typing import Optional, Tuple, Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conv_transformer import ConvTransformer
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-from text_recognizer.networks.transformer.layers import Decoder
-
-
-class VqTransformer(ConvTransformer):
- """Convolutional encoder and transformer decoder network."""
-
- def __init__(
- self,
- input_dims: Tuple[int, int, int],
- hidden_dim: int,
- num_classes: int,
- pad_index: Tensor,
- encoder: nn.Module,
- decoder: Decoder,
- pixel_pos_embedding: Type[nn.Module],
- quantizer: VectorQuantizer,
- token_pos_embedding: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__(
- input_dims=input_dims,
- hidden_dim=hidden_dim,
- num_classes=num_classes,
- pad_index=pad_index,
- encoder=encoder,
- decoder=decoder,
- pixel_pos_embedding=pixel_pos_embedding,
- token_pos_embedding=token_pos_embedding,
- )
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes an image into a discrete (VQ) latent representation.
-
- Args:
- x (Tensor): Image tensor.
-
- Shape:
- - x: :math: `(B, C, H, W)`
- - z: :math: `(B, Sx, E)`
-
- where Sx is the length of the flattened feature maps projected from
- the encoder. E latent dimension for each pixel in the projected
- feature maps.
-
- Returns:
- Tensor: A Latent embedding of the image.
- """
- z = self.encoder(x)
- z = self.conv(z)
- z, _, commitment_loss = self.quantizer(z)
- z = self.pixel_pos_embedding(z)
- z = z.flatten(start_dim=2)
-
- # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
- z = z.permute(0, 2, 1)
- return z, commitment_loss
-
- def forward(self, x: Tensor, context: Tensor) -> Tensor:
- """Encodes images into word piece logtis.
-
- Args:
- x (Tensor): Input image(s).
- context (Tensor): Target word embeddings.
-
- Shapes:
- - x: :math: `(B, C, H, W)`
- - context: :math: `(B, Sy, T)`
-
- where B is the batch size, C is the number of input channels, H is
- the image height and W is the image width.
-
- Returns:
- Tensor: Sequence of logits.
- """
- z, commitment_loss = self.encode(x)
- logits = self.decode(z, context)
- return logits, commitment_loss
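
For context on the deleted encode method above, here is a minimal, self-contained sketch (plain PyTorch; the batch size, latent dimension, and feature-map size are illustrative assumptions, not values from this repository) of the flatten-and-permute step that turned the quantized feature maps into a decoder-ready sequence:

import torch

# Assumed dimensions for illustration: batch B, latent dim E, feature map Ho x Wo.
B, E, Ho, Wo = 2, 128, 8, 32
z = torch.randn(B, E, Ho, Wo)  # stands in for the quantized encoder output

z = z.flatten(start_dim=2)  # collapse the spatial dims: (B, E, Ho * Wo)
z = z.permute(0, 2, 1)      # reorder to (B, Sx, E) with Sx = Ho * Wo
assert z.shape == (B, Ho * Wo, E)

Because forward returned both the logits and the quantizer's commitment loss, a caller was expected to combine the two. A hypothetical training step (model, images, targets, and recognition_loss_fn are placeholder names, not part of this repository) might have looked like:

logits, commitment_loss = model(images, targets)   # VqTransformer.forward returns a pair
loss = recognition_loss_fn(logits, targets) + commitment_loss
loss.backward()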