summary refs log tree commit diff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-02-03 21:39:17 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-02-03 21:39:54 +0100
commite6717d5a872e236f90977519a76cb35446ab0d5d (patch)
tree825c6da6b46dea93ec618663a1b7f5ab324998e1 /text_recognizer/networks/conv_transformer.py
parent50fcf0ce31fdad608df26a483d80a83a5738b40f (diff)
chore: remove axial attention
chore: remove axial attention
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- text_recognizer/networks/conv_transformer.py | 4
1 file changed, 0 insertions(+), 4 deletions(-)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index a068ea3..5b29362 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -6,7 +6,6 @@ from loguru import logger as log
from torch import nn, Tensor
from text_recognizer.networks.base import BaseTransformer
-from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
from text_recognizer.networks.transformer.decoder import Decoder
from text_recognizer.networks.transformer.embeddings.axial import (
AxialPositionalEmbedding,
@@ -24,7 +23,6 @@ class ConvTransformer(BaseTransformer):
pad_index: Tensor,
encoder: Type[nn.Module],
decoder: Decoder,
- axial_encoder: Optional[AxialEncoder],
pixel_pos_embedding: AxialPositionalEmbedding,
token_pos_embedding: Optional[Type[nn.Module]] = None,
) -> None:
@@ -39,7 +37,6 @@ class ConvTransformer(BaseTransformer):
)
self.pixel_pos_embedding = pixel_pos_embedding
- self.axial_encoder = axial_encoder
# Latent projector for down sampling number of filters and 2d
# positional encoding.
@@ -79,7 +76,6 @@ class ConvTransformer(BaseTransformer):
z = self.encoder(x)
z = self.conv(z)
z = self.pixel_pos_embedding(z)
- z = self.axial_encoder(z) if self.axial_encoder is not None else z
z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]