diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-17 22:43:54 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-17 22:43:54 +0100 |
commit | 8b06e0ff3185436848c08bf04d730d7e5212e0e5 (patch) | |
tree | 2de819b8eb6be9951f1fb5ba434d6c9c77f146e9 | |
parent | 700ce6ed83867601de0ae55032afdd5e12438258 (diff) |
Update encoder fun in conv_transformer
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 59ce814..b554695 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -29,17 +29,14 @@ class ConvTransformer(nn.Module): self.pad_index = pad_index self.encoder = encoder self.decoder = decoder + self.pixel_pos_embedding = pixel_pos_embedding # Latent projector for down sampling number of filters and 2d # positional encoding. - self.latent_encoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ), - pixel_pos_embedding, - nn.Flatten(start_dim=2), + self.conv = nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, ) # Token embedding. @@ -87,7 +84,9 @@ class ConvTransformer(nn.Module): Tensor: A Latent embedding of the image. """ z = self.encoder(x) - z = self.latent_encoder(z) + z = self.conv(z) + z = self.pixel_pos_embedding(z) + z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) |