From 7be90f5f101d7ace7ff07180950dac4c11086ec1 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 13 Sep 2022 18:12:13 +0200
Subject: Add axial encoder

---
 text_recognizer/networks/conv_transformer.py | 20 ++++++--------------
 1 file changed, 6 insertions(+), 14 deletions(-)

(limited to 'text_recognizer/networks/conv_transformer.py')

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 365906f..40047ad 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -5,7 +5,7 @@ from torch import nn, Tensor
 
 from text_recognizer.networks.transformer.decoder import Decoder
 from text_recognizer.networks.transformer.embeddings.axial import (
-    AxialPositionalEmbedding,
+    AxialPositionalEmbeddingImage,
 )
 
 
@@ -20,8 +20,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: Decoder,
-        pixel_embedding: AxialPositionalEmbedding,
-        token_pos_embedding: Optional[Type[nn.Module]] = None,
+        pixel_embedding: AxialPositionalEmbeddingImage,
+        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -37,11 +37,7 @@ class ConvTransformer(nn.Module):
         )
 
         # Positional encoding for decoder tokens.
-        if not self.decoder.has_pos_emb:
-            self.token_pos_embedding = token_pos_embedding
-        else:
-            self.token_pos_embedding = None
-
+        self.token_pos_embedding = token_pos_embedding
         self.pixel_embedding = pixel_embedding
 
         # Latent projector for down sampling number of filters and 2d
@@ -83,7 +79,7 @@ class ConvTransformer(nn.Module):
         """
         z = self.encoder(x)
         z = self.conv(z)
-        z = self.pixel_embedding(z)
+        z += self.pixel_embedding(z)
         z = z.flatten(start_dim=2)
 
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
@@ -110,11 +106,7 @@ class ConvTransformer(nn.Module):
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg)
-        trg = (
-            self.token_pos_embedding(trg)
-            if self.token_pos_embedding is not None
-            else trg
-        )
+        trg += self.token_pos_embedding(trg)
        out = self.decoder(x=trg, context=src, input_mask=trg_mask)
         logits = self.to_logits(out)  # [B, Sy, C]
         logits = logits.permute(0, 2, 1)  # [B, C, Sy]
-- 
cgit v1.2.3-70-g09d2
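
Note (not part of the commit): the call sites change from "z = self.pixel_embedding(z)" to "z += self.pixel_embedding(z)" and from the conditional token-embedding expression to "trg += self.token_pos_embedding(trg)", which is consistent with the embedding modules now returning only the positional code, leaving the addition to the caller; that is an assumption, since the patch does not show the embedding implementations. The repository's actual AxialPositionalEmbeddingImage lives in text_recognizer/networks/transformer/embeddings/axial.py and is not shown here. Below is a minimal, hypothetical sketch of an axial positional embedding for a [B, E, H, W] feature map under that assumption; the class name, initialization scale, and example shapes are made up for illustration.

import torch
from typing import Tuple
from torch import Tensor, nn


class AxialPositionalEmbeddingImageSketch(nn.Module):
    """Hypothetical axial positional embedding for a [B, E, H, W] feature map.

    One learned vector per row and one per column; their broadcast sum gives a
    distinct positional code for every (row, column) position. The forward pass
    returns only the code, so the caller adds it to the feature map.
    """

    def __init__(self, dim: int, shape: Tuple[int, int]) -> None:
        super().__init__()
        height, width = shape
        self.row_emb = nn.Parameter(torch.randn(1, dim, height, 1) * 0.02)
        self.col_emb = nn.Parameter(torch.randn(1, dim, 1, width) * 0.02)

    def forward(self, z: Tensor) -> Tensor:
        _, _, h, w = z.shape
        # Slice in case the feature map is smaller than the configured shape.
        return self.row_emb[:, :, :h, :] + self.col_emb[:, :, :, :w]


if __name__ == "__main__":
    emb = AxialPositionalEmbeddingImageSketch(dim=128, shape=(18, 20))
    z = torch.randn(2, 128, 18, 20)  # stand-in for the encoder output after self.conv
    z = z + emb(z)                   # same effect as "z += self.pixel_embedding(z)"
    print(z.shape)                   # torch.Size([2, 128, 18, 20])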