From e586a88ff662ff1841f51c0034679a945d1b79ee Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 13 Sep 2022 18:47:50 +0200 Subject: Remove vq transformer --- text_recognizer/networks/vq_transformer.py | 57 ------------------------------ 1 file changed, 57 deletions(-) delete mode 100644 text_recognizer/networks/vq_transformer.py (limited to 'text_recognizer/networks/vq_transformer.py') diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py deleted file mode 100644 index c12e18b..0000000 --- a/text_recognizer/networks/vq_transformer.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Optional, Tuple, Type - -from torch import nn, Tensor - -from text_recognizer.networks.transformer.decoder import Decoder -from text_recognizer.networks.transformer.embeddings.axial import ( - AxialPositionalEmbedding, -) - -from text_recognizer.networks.conv_transformer import ConvTransformer -from text_recognizer.networks.quantizer.quantizer import VectorQuantizer - - -class VqTransformer(ConvTransformer): - def __init__( - self, - input_dims: Tuple[int, int, int], - hidden_dim: int, - num_classes: int, - pad_index: Tensor, - encoder: Type[nn.Module], - decoder: Decoder, - pixel_embedding: AxialPositionalEmbedding, - token_pos_embedding: Optional[Type[nn.Module]] = None, - quantizer: Optional[VectorQuantizer] = None, - ) -> None: - super().__init__( - input_dims, - hidden_dim, - num_classes, - pad_index, - encoder, - decoder, - pixel_embedding, - token_pos_embedding, - ) - self.quantizer = quantizer - - def quantize(self, z: Tensor) -> Tuple[Tensor, Tensor]: - q, _, loss = self.quantizer(z) - return q, loss - - def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: - z = self.encoder(x) - z = self.conv(z) - q, loss = self.quantize(z) - z = self.pixel_embedding(q) - z = z.flatten(start_dim=2) - - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] - z = z.permute(0, 2, 1) - return z, loss - - def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]: - z, loss = self.encode(x) - logits = self.decode(z, context) - return logits, loss -- cgit v1.2.3-70-g09d2