From a9363f3944f1ad31590c48d5d51c45df3bbf43b1 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 21 Nov 2021 21:35:34 +0100
Subject: Add axial encoder to conv transformer

---
 text_recognizer/networks/conv_transformer.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index b554695..da99bbf 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,10 +1,14 @@
 """Vision transformer for character recognition."""
 import math
-from typing import Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
 
 from loguru import logger as log
 from torch import nn, Tensor
 
+from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
+from text_recognizer.networks.transformer.embeddings.axial import (
+    AxialPositionalEmbedding,
+)
 from text_recognizer.networks.transformer.layers import Decoder
 
 
@@ -19,7 +23,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: nn.Module,
         decoder: Decoder,
-        pixel_pos_embedding: Type[nn.Module],
+        axial_encoder: Optional[AxialEncoder],
+        pixel_pos_embedding: AxialPositionalEmbedding,
         token_pos_embedding: Optional[Type[nn.Module]] = None,
     ) -> None:
         super().__init__()
@@ -29,6 +34,7 @@ class ConvTransformer(nn.Module):
         self.pad_index = pad_index
         self.encoder = encoder
         self.decoder = decoder
+        self.axial_encoder = axial_encoder
         self.pixel_pos_embedding = pixel_pos_embedding
 
         # Latent projector for down sampling number of filters and 2d
@@ -86,6 +92,7 @@ class ConvTransformer(nn.Module):
         z = self.encoder(x)
         z = self.conv(z)
         z = self.pixel_pos_embedding(z)
+        z = self.axial_encoder(z) if self.axial_encoder is not None else z
         z = z.flatten(start_dim=2)
 
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
-- 
cgit v1.2.3-70-g09d2
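
The hunk in forward() makes the axial encoder a purely optional refinement stage between the pixel positional embedding and the flatten-to-sequence step. Below is a minimal sketch of that wiring; it is not part of the patch. The encoder, conv, and embedding modules here are hypothetical stand-ins (plain Conv2d and nn.Identity), since the real AxialEncoder and AxialPositionalEmbedding live in the text_recognizer repo.

    from typing import Optional

    import torch
    from torch import nn, Tensor


    class EncoderPath(nn.Module):
        """Mirrors the patched encode path: backbone -> 1x1 conv ->
        positional embedding -> optional axial encoder -> sequence."""

        def __init__(
            self,
            encoder: nn.Module,
            conv: nn.Module,
            pixel_pos_embedding: nn.Module,
            axial_encoder: Optional[nn.Module] = None,
        ) -> None:
            super().__init__()
            self.encoder = encoder
            self.conv = conv
            self.pixel_pos_embedding = pixel_pos_embedding
            self.axial_encoder = axial_encoder

        def forward(self, x: Tensor) -> Tensor:
            z = self.encoder(x)
            z = self.conv(z)
            z = self.pixel_pos_embedding(z)
            # The axial encoder is optional: apply it to the feature map
            # only when one was provided, exactly as in the patch.
            z = self.axial_encoder(z) if self.axial_encoder is not None else z
            # [B, E, Ho, Wo] -> [B, E, Ho * Wo] -> [B, Sx, E]
            z = z.flatten(start_dim=2)
            return z.permute(0, 2, 1)


    # Usage sketch with placeholder modules.
    path = EncoderPath(
        encoder=nn.Conv2d(1, 8, kernel_size=3, padding=1),
        conv=nn.Conv2d(8, 8, kernel_size=1),
        pixel_pos_embedding=nn.Identity(),
        axial_encoder=None,  # or an AxialEncoder instance from the repo
    )
    out = path(torch.randn(2, 1, 16, 32))  # -> [2, 16 * 32, 8]

Keeping axial_encoder as Optional[AxialEncoder] lets existing configs that omit it keep working unchanged, since the conditional in forward() simply passes the feature map through untouched.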