From e586a88ff662ff1841f51c0034679a945d1b79ee Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Tue, 13 Sep 2022 18:47:50 +0200
Subject: Remove the VqTransformer network
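
Delete text_recognizer/networks/vq_transformer.py. VqTransformer subclassed
ConvTransformer, inserting an optional VectorQuantizer between the
convolutional encoder and the axial pixel embedding, and returned the
quantizer's loss alongside the decoder logits.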

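Should this behaviour ever be needed again, here is a minimal sketch of
recovering it by composing a ConvTransformer with a quantizer externally.
It assumes both keep the interfaces used in the removed code below: the
encoder, conv, and pixel_embedding attributes on ConvTransformer, and the
q, _, loss = quantizer(z) calling convention of VectorQuantizer.

    from typing import Tuple

    from torch import Tensor

    from text_recognizer.networks.conv_transformer import ConvTransformer
    from text_recognizer.networks.quantizer.quantizer import VectorQuantizer


    def encode_with_quantizer(
        net: ConvTransformer, quantizer: VectorQuantizer, x: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Quantize the encoder output before flattening, as VqTransformer did."""
        z = net.encoder(x)          # backbone features
        z = net.conv(z)             # projection applied before quantization
        q, _, loss = quantizer(z)   # quantized codes and commitment loss
        z = net.pixel_embedding(q)  # axial positional embedding
        z = z.flatten(start_dim=2)  # [B, E, Ho, Wo] -> [B, E, Ho * Wo]
        z = z.permute(0, 2, 1)      # [B, E, Ho * Wo] -> [B, Sx, E]
        return z, loss

The decoder is then applied to (z, context) as the removed forward did via
self.decode, with loss added to the training objective.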
---
 text_recognizer/networks/vq_transformer.py | 57 ------------------------------
 1 file changed, 57 deletions(-)
 delete mode 100644 text_recognizer/networks/vq_transformer.py

diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index c12e18b..0000000
--- a/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Optional, Tuple, Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.transformer.decoder import Decoder
-from text_recognizer.networks.transformer.embeddings.axial import (
-    AxialPositionalEmbedding,
-)
-
-from text_recognizer.networks.conv_transformer import ConvTransformer
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-
-
-class VqTransformer(ConvTransformer):
-    def __init__(
-        self,
-        input_dims: Tuple[int, int, int],
-        hidden_dim: int,
-        num_classes: int,
-        pad_index: Tensor,
-        encoder: Type[nn.Module],
-        decoder: Decoder,
-        pixel_embedding: AxialPositionalEmbedding,
-        token_pos_embedding: Optional[Type[nn.Module]] = None,
-        quantizer: Optional[VectorQuantizer] = None,
-    ) -> None:
-        super().__init__(
-            input_dims,
-            hidden_dim,
-            num_classes,
-            pad_index,
-            encoder,
-            decoder,
-            pixel_embedding,
-            token_pos_embedding,
-        )
-        self.quantizer = quantizer
-
-    def quantize(self, z: Tensor) -> Tuple[Tensor, Tensor]:
-        q, _, loss = self.quantizer(z)
-        return q, loss
-
-    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-        z = self.encoder(x)
-        z = self.conv(z)
-        q, loss = self.quantize(z)
-        z = self.pixel_embedding(q)
-        z = z.flatten(start_dim=2)
-
-        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
-        z = z.permute(0, 2, 1)
-        return z, loss
-
-    def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]:
-        z, loss = self.encode(x)
-        logits = self.decode(z, context)
-        return logits, loss
-- 
cgit v1.2.3-70-g09d2