diff options
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 4 |
1 files changed, 0 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index a068ea3..5b29362 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -6,7 +6,6 @@ from loguru import logger as log from torch import nn, Tensor from text_recognizer.networks.base import BaseTransformer -from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder from text_recognizer.networks.transformer.decoder import Decoder from text_recognizer.networks.transformer.embeddings.axial import ( AxialPositionalEmbedding, @@ -24,7 +23,6 @@ class ConvTransformer(BaseTransformer): pad_index: Tensor, encoder: Type[nn.Module], decoder: Decoder, - axial_encoder: Optional[AxialEncoder], pixel_pos_embedding: AxialPositionalEmbedding, token_pos_embedding: Optional[Type[nn.Module]] = None, ) -> None: @@ -39,7 +37,6 @@ class ConvTransformer(BaseTransformer): ) self.pixel_pos_embedding = pixel_pos_embedding - self.axial_encoder = axial_encoder # Latent projector for down sampling number of filters and 2d # positional encoding. @@ -79,7 +76,6 @@ class ConvTransformer(BaseTransformer): z = self.encoder(x) z = self.conv(z) z = self.pixel_pos_embedding(z) - z = self.axial_encoder(z) if self.axial_encoder is not None else z z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] |