From 8b06e0ff3185436848c08bf04d730d7e5212e0e5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 17 Nov 2021 22:43:54 +0100 Subject: Update encoder fun in conv_transformer --- text_recognizer/networks/conv_transformer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 59ce814..b554695 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -29,17 +29,14 @@ class ConvTransformer(nn.Module): self.pad_index = pad_index self.encoder = encoder self.decoder = decoder + self.pixel_pos_embedding = pixel_pos_embedding # Latent projector for down sampling number of filters and 2d # positional encoding. - self.latent_encoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ), - pixel_pos_embedding, - nn.Flatten(start_dim=2), + self.conv = nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, ) # Token embedding. @@ -87,7 +84,9 @@ class ConvTransformer(nn.Module): Tensor: A Latent embedding of the image. """ z = self.encoder(x) - z = self.latent_encoder(z) + z = self.conv(z) + z = self.pixel_pos_embedding(z) + z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) -- cgit v1.2.3-70-g09d2