diff options
Diffstat (limited to 'text_recognizer')
| -rw-r--r-- | text_recognizer/networks/conv_transformer.py | 17 | 
1 files changed, 8 insertions, 9 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 59ce814..b554695 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -29,17 +29,14 @@ class ConvTransformer(nn.Module):          self.pad_index = pad_index          self.encoder = encoder          self.decoder = decoder +        self.pixel_pos_embedding = pixel_pos_embedding          # Latent projector for down sampling number of filters and 2d          # positional encoding. -        self.latent_encoder = nn.Sequential( -            nn.Conv2d( -                in_channels=self.encoder.out_channels, -                out_channels=self.hidden_dim, -                kernel_size=1, -            ), -            pixel_pos_embedding, -            nn.Flatten(start_dim=2), +        self.conv = nn.Conv2d( +            in_channels=self.encoder.out_channels, +            out_channels=self.hidden_dim, +            kernel_size=1,          )          # Token embedding. @@ -87,7 +84,9 @@ class ConvTransformer(nn.Module):              Tensor: A Latent embedding of the image.          """          z = self.encoder(x) -        z = self.latent_encoder(z) +        z = self.conv(z) +        z = self.pixel_pos_embedding(z) +        z = z.flatten(start_dim=2)          # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]          z = z.permute(0, 2, 1)  |