diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-27 18:17:42 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-27 18:17:42 +0200 |
commit | 889731f491266a0626861f07fe2a0de1f6caa468 (patch) | |
tree | f627b5c2102036f91fb0c6d20ab622c36412d059 /text_recognizer/networks/vq_transformer.py | |
parent | 68cda1e6abf76312eb13f45ca3c8565a0b00d745 (diff) |
Add vq transformer network
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py new file mode 100644 index 0000000..c12e18b --- /dev/null +++ b/text_recognizer/networks/vq_transformer.py @@ -0,0 +1,57 @@ +from typing import Optional, Tuple, Type + +from torch import nn, Tensor + +from text_recognizer.networks.transformer.decoder import Decoder +from text_recognizer.networks.transformer.embeddings.axial import ( + AxialPositionalEmbedding, +) + +from text_recognizer.networks.conv_transformer import ConvTransformer +from text_recognizer.networks.quantizer.quantizer import VectorQuantizer + + +class VqTransformer(ConvTransformer): + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + num_classes: int, + pad_index: Tensor, + encoder: Type[nn.Module], + decoder: Decoder, + pixel_embedding: AxialPositionalEmbedding, + token_pos_embedding: Optional[Type[nn.Module]] = None, + quantizer: Optional[VectorQuantizer] = None, + ) -> None: + super().__init__( + input_dims, + hidden_dim, + num_classes, + pad_index, + encoder, + decoder, + pixel_embedding, + token_pos_embedding, + ) + self.quantizer = quantizer + + def quantize(self, z: Tensor) -> Tuple[Tensor, Tensor]: + q, _, loss = self.quantizer(z) + return q, loss + + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + z = self.encoder(x) + z = self.conv(z) + q, loss = self.quantize(z) + z = self.pixel_embedding(q) + z = z.flatten(start_dim=2) + + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z, loss + + def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]: + z, loss = self.encode(x) + logits = self.decode(z, context) + return logits, loss |