diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 17 |
1 file changed, 8 insertions, 9 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 59ce814..b554695 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -29,17 +29,14 @@ class ConvTransformer(nn.Module): self.pad_index = pad_index self.encoder = encoder self.decoder = decoder + self.pixel_pos_embedding = pixel_pos_embedding # Latent projector for down sampling number of filters and 2d # positional encoding. - self.latent_encoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ), - pixel_pos_embedding, - nn.Flatten(start_dim=2), + self.conv = nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, ) # Token embedding. @@ -87,7 +84,9 @@ class ConvTransformer(nn.Module): Tensor: A Latent embedding of the image. """ z = self.encoder(x) - z = self.latent_encoder(z) + z = self.conv(z) + z = self.pixel_pos_embedding(z) + z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) |