| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:13 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:13 +0200 | 
| commit | 7be90f5f101d7ace7ff07180950dac4c11086ec1 | |
| tree | a99c0fc55dd45f8e4eda39a958d68863885cfd3f /text_recognizer/networks/conv_transformer.py | |
| parent | 12abf17cd7c31ae4599be366505a4423fbba4044 | |
Add axial encoder
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
| -rw-r--r-- | text_recognizer/networks/conv_transformer.py | 20 | 
1 file changed, 6 insertions, 14 deletions
```diff
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 365906f..40047ad 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -5,7 +5,7 @@ from torch import nn, Tensor
 
 from text_recognizer.networks.transformer.decoder import Decoder
 from text_recognizer.networks.transformer.embeddings.axial import (
-    AxialPositionalEmbedding,
+    AxialPositionalEmbeddingImage,
 )
@@ -20,8 +20,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: Decoder,
-        pixel_embedding: AxialPositionalEmbedding,
-        token_pos_embedding: Optional[Type[nn.Module]] = None,
+        pixel_embedding: AxialPositionalEmbeddingImage,
+        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -37,11 +37,7 @@ class ConvTransformer(nn.Module):
         )
 
         # Positional encoding for decoder tokens.
-        if not self.decoder.has_pos_emb:
-            self.token_pos_embedding = token_pos_embedding
-        else:
-            self.token_pos_embedding = None
-
+        self.token_pos_embedding = token_pos_embedding
         self.pixel_embedding = pixel_embedding
 
         # Latent projector for down sampling number of filters and 2d
@@ -83,7 +79,7 @@ class ConvTransformer(nn.Module):
         """
         z = self.encoder(x)
         z = self.conv(z)
-        z = self.pixel_embedding(z)
+        z += self.pixel_embedding(z)
         z = z.flatten(start_dim=2)
 
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
@@ -110,11 +106,7 @@ class ConvTransformer(nn.Module):
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg)
-        trg = (
-            self.token_pos_embedding(trg)
-            if self.token_pos_embedding is not None
-            else trg
-        )
+        trg += self.token_pos_embedding(trg)
         out = self.decoder(x=trg, context=src, input_mask=trg_mask)
         logits = self.to_logits(out)  # [B, Sy, C]
         logits = logits.permute(0, 2, 1)  # [B, C, Sy]
```
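The behavioral change in this commit is that both positional embeddings now return only a positional term that is *added* to the features (`z += self.pixel_embedding(z)`, `trg += self.token_pos_embedding(trg)`) instead of replacing them, and the token positional embedding becomes a required argument rather than an `Optional` one. Below is a minimal, hypothetical sketch of what an axial positional embedding over a 2D feature map can look like; the class name, constructor arguments, and shapes are illustrative assumptions, not the repository's actual `AxialPositionalEmbeddingImage` implementation.

```python
from typing import Tuple

import torch
from torch import nn, Tensor


class AxialPositionalEmbeddingImageSketch(nn.Module):
    """Hypothetical sketch: learned per-axis embeddings broadcast to a 2D grid."""

    def __init__(self, dim: int, axial_shape: Tuple[int, int]) -> None:
        super().__init__()
        height, width = axial_shape
        # One learned vector per row and per column; broadcasting their sum
        # yields a [1, dim, height, width] grid without storing H * W vectors.
        self.height_emb = nn.Parameter(torch.randn(1, dim, height, 1))
        self.width_emb = nn.Parameter(torch.randn(1, dim, 1, width))

    def forward(self, x: Tensor) -> Tensor:
        # x: [B, dim, H, W]. Return only the positional term, so the caller
        # adds it to the features, mirroring `z += self.pixel_embedding(z)`.
        _, _, h, w = x.shape
        return self.height_emb[:, :, :h, :] + self.width_emb[:, :, :, :w]


# Usage mirroring the forward pass in the diff: add, then flatten.
emb = AxialPositionalEmbeddingImageSketch(dim=128, axial_shape=(7, 32))
z = torch.randn(2, 128, 7, 32)      # [B, E, Ho, Wo] conv features
z = z + emb(z)                      # corresponds to z += self.pixel_embedding(z)
z = z.flatten(start_dim=2)          # [B, E, Ho * Wo]
```

Factorizing the grid into per-axis parameters keeps the parameter count at `dim * (H + W)` instead of `dim * H * W`, which is the usual motivation for axial positional embeddings on images.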