summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/conv_transformer.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 59ce814..b554695 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -29,17 +29,14 @@ class ConvTransformer(nn.Module):
self.pad_index = pad_index
self.encoder = encoder
self.decoder = decoder
+ self.pixel_pos_embedding = pixel_pos_embedding
# Latent projector for down sampling number of filters and 2d
# positional encoding.
- self.latent_encoder = nn.Sequential(
- nn.Conv2d(
- in_channels=self.encoder.out_channels,
- out_channels=self.hidden_dim,
- kernel_size=1,
- ),
- pixel_pos_embedding,
- nn.Flatten(start_dim=2),
+ self.conv = nn.Conv2d(
+ in_channels=self.encoder.out_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=1,
)
# Token embedding.
@@ -87,7 +84,9 @@ class ConvTransformer(nn.Module):
Tensor: A Latent embedding of the image.
"""
z = self.encoder(x)
- z = self.latent_encoder(z)
+ z = self.conv(z)
+ z = self.pixel_pos_embedding(z)
+ z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
z = z.permute(0, 2, 1)