author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-13 18:12:13 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-13 18:12:13 +0200
commit    7be90f5f101d7ace7ff07180950dac4c11086ec1 (patch)
tree      a99c0fc55dd45f8e4eda39a958d68863885cfd3f /text_recognizer/networks/conv_transformer.py
parent    12abf17cd7c31ae4599be366505a4423fbba4044 (diff)
Add axial encoder
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r--  text_recognizer/networks/conv_transformer.py | 20 ++++++--------------
1 file changed, 6 insertions(+), 14 deletions(-)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 365906f..40047ad 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -5,7 +5,7 @@ from torch import nn, Tensor
 from text_recognizer.networks.transformer.decoder import Decoder
 from text_recognizer.networks.transformer.embeddings.axial import (
-    AxialPositionalEmbedding,
+    AxialPositionalEmbeddingImage,
 )
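The renamed `AxialPositionalEmbeddingImage` lives elsewhere in the repo and is not shown in this diff. Assuming it follows the usual axial scheme — one learned embedding per row and one per column, whose broadcast sum gives every (row, col) position a distinct vector at far fewer parameters than a full 2D table — a minimal sketch could look like this. The class name and constructor here are illustrative, not the repo's actual API:

```python
import torch
from torch import nn, Tensor
from typing import Tuple


class AxialEmbeddingSketch(nn.Module):
    """Minimal sketch of an axial 2D positional embedding (assumed design)."""

    def __init__(self, dim: int, shape: Tuple[int, int]) -> None:
        super().__init__()
        h, w = shape
        # One learned vector per row and per column of the feature grid.
        self.row_emb = nn.Parameter(torch.randn(1, dim, h, 1) * 0.02)
        self.col_emb = nn.Parameter(torch.randn(1, dim, 1, w) * 0.02)

    def forward(self, x: Tensor) -> Tensor:
        # Broadcast sum covers the grid; crop to the input's spatial size.
        # Returns only the positional term; the caller adds it to x.
        return (self.row_emb + self.col_emb)[:, :, : x.shape[2], : x.shape[3]]
```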
@@ -20,8 +20,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: Decoder,
-        pixel_embedding: AxialPositionalEmbedding,
-        token_pos_embedding: Optional[Type[nn.Module]] = None,
+        pixel_embedding: AxialPositionalEmbeddingImage,
+        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
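Dropping `Optional` makes the token positional embedding a required constructor argument. Since the forward pass now adds the module's output unconditionally (last hunk below), a caller that previously passed `None` would presumably supply a no-op module instead — one that returns zeros rather than its input, because the result is added onto `trg`. A hypothetical stand-in:

```python
import torch
from torch import nn, Tensor


class ZeroPosEmbedding(nn.Module):
    """Hypothetical no-op: contributes nothing under `trg += emb(trg)`."""

    def forward(self, x: Tensor) -> Tensor:
        return torch.zeros_like(x)
```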
@@ -37,11 +37,7 @@
         )
         # Positional encoding for decoder tokens.
-        if not self.decoder.has_pos_emb:
-            self.token_pos_embedding = token_pos_embedding
-        else:
-            self.token_pos_embedding = None
-
+        self.token_pos_embedding = token_pos_embedding
         self.pixel_embedding = pixel_embedding
 
         # Latent projector for down sampling number of filters and 2d
@@ -83,7 +79,7 @@
         """
         z = self.encoder(x)
         z = self.conv(z)
-        z = self.pixel_embedding(z)
+        z += self.pixel_embedding(z)
         z = z.flatten(start_dim=2)
 
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
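The behavioral change in this hunk: the old `pixel_embedding` apparently returned the already-augmented feature map, while `AxialPositionalEmbeddingImage` seems to return only the positional term, which is now added onto `z`. A toy shape walk-through of the encode path, with assumed sizes (B=2, E=128, Ho=4, Wo=16):

```python
import torch

z = torch.randn(2, 128, 4, 16)    # output of self.conv(z): [B, E, Ho, Wo]
pos = torch.randn(1, 128, 4, 16)  # broadcastable axial positional term
z = z + pos                       # z += self.pixel_embedding(z)
z = z.flatten(start_dim=2)        # [2, 128, 64]: [B, E, Ho * Wo]
z = z.permute(0, 2, 1)            # [2, 64, 128]: [B, Sx, E]
assert z.shape == (2, 4 * 16, 128)
```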
@@ -110,11 +106,7 @@
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg)
-        trg = (
-            self.token_pos_embedding(trg)
-            if self.token_pos_embedding is not None
-            else trg
-        )
+        trg += self.token_pos_embedding(trg)
         out = self.decoder(x=trg, context=src, input_mask=trg_mask)
         logits = self.to_logits(out)  # [B, Sy, C]
         logits = logits.permute(0, 2, 1)  # [B, C, Sy]
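Same pattern on the decoder side: the `None` check disappears and the positional term is added unconditionally. For `trg += self.token_pos_embedding(trg)` to be shape-correct, the module presumably takes the embedded sequence and returns a same-shaped (or broadcastable) positional tensor. A self-contained sketch of that call pattern — all names and sizes here are assumptions, not the repo's API:

```python
import torch
from torch import nn, Tensor


class TokenPosEmbeddingSketch(nn.Module):
    """Hypothetical module matching `trg += self.token_pos_embedding(trg)`:
    consumes the embedded sequence, returns one learned vector per position."""

    def __init__(self, max_len: int, dim: int) -> None:
        super().__init__()
        self.emb = nn.Embedding(max_len, dim)

    def forward(self, x: Tensor) -> Tensor:
        # x: [B, Sy, E] -> [1, Sy, E], broadcast over the batch when added.
        positions = torch.arange(x.shape[1], device=x.device)
        return self.emb(positions).unsqueeze(0)


token_embedding = nn.Embedding(58, 128)    # assumed vocab C=58, dim E=128
pos_embedding = TokenPosEmbeddingSketch(64, 128)
trg = torch.randint(0, 58, (2, 32))        # [B, Sy] token ids
emb = token_embedding(trg.long())          # [B, Sy, E]
emb = emb + pos_embedding(emb)             # the simplified line above
```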