diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-02-03 21:39:17 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-02-03 21:39:54 +0100 |
commit | e6717d5a872e236f90977519a76cb35446ab0d5d (patch) | |
tree | 825c6da6b46dea93ec618663a1b7f5ab324998e1 /text_recognizer/networks/conv_transformer.py | |
parent | 50fcf0ce31fdad608df26a483d80a83a5738b40f (diff) |
chore: remove axial attention
chore: remove axial attention
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 4 |
1 file changed, 0 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index a068ea3..5b29362 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -6,7 +6,6 @@ from loguru import logger as log
 from torch import nn, Tensor

 from text_recognizer.networks.base import BaseTransformer
-from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
 from text_recognizer.networks.transformer.decoder import Decoder
 from text_recognizer.networks.transformer.embeddings.axial import (
     AxialPositionalEmbedding,
@@ -24,7 +23,6 @@ class ConvTransformer(BaseTransformer):
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: Decoder,
-        axial_encoder: Optional[AxialEncoder],
         pixel_pos_embedding: AxialPositionalEmbedding,
         token_pos_embedding: Optional[Type[nn.Module]] = None,
     ) -> None:
@@ -39,7 +37,6 @@ class ConvTransformer(BaseTransformer):
         )

         self.pixel_pos_embedding = pixel_pos_embedding
-        self.axial_encoder = axial_encoder

         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
@@ -79,7 +76,6 @@ class ConvTransformer(BaseTransformer):
         z = self.encoder(x)
         z = self.conv(z)
         z = self.pixel_pos_embedding(z)
-        z = self.axial_encoder(z) if self.axial_encoder is not None else z
         z = z.flatten(start_dim=2)

         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]