Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/conv_transformer.py | 17 +++++++----------
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 714792e..60c0ef8 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,6 +1,6 @@
 """Vision transformer for character recognition."""
 import math
-from typing import Tuple
+from typing import Tuple, Type
 
 from torch import nn, Tensor
 
@@ -24,6 +24,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: nn.Module,
         decoder: Decoder,
+        pixel_pos_embedding: Type[nn.Module],
+        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -43,11 +45,7 @@ class ConvTransformer(nn.Module):
                 out_channels=self.hidden_dim,
                 kernel_size=1,
             ),
-            PositionalEncoding2D(
-                hidden_dim=self.hidden_dim,
-                max_h=self.input_dims[1],
-                max_w=self.input_dims[2],
-            ),
+            pixel_pos_embedding,
             nn.Flatten(start_dim=2),
         )
 
@@ -57,9 +55,8 @@ class ConvTransformer(nn.Module):
         )
 
         # Positional encoding for decoder tokens.
-        self.token_pos_encoder = PositionalEncoding(
-            hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate
-        )
+        self.token_pos_embedding = token_pos_embedding
+
         # Head
         self.head = nn.Linear(
             in_features=self.hidden_dim, out_features=self.num_classes
@@ -119,7 +116,7 @@ class ConvTransformer(nn.Module):
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = self.token_pos_encoder(trg)
+        trg = self.token_pos_embedding(trg)
         out = self.decoder(x=trg, context=src, mask=trg_mask)
         logits = self.head(out)  # [B, Sy, T]
         logits = logits.permute(0, 2, 1)  # [B, T, Sy]
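The change is a dependency-injection refactor: instead of constructing PositionalEncoding2D and PositionalEncoding internally, ConvTransformer now receives both positional-embedding modules from the caller, so a config-driven factory can swap implementations without editing the network. Below is a minimal, self-contained sketch of that pattern; ToyEncoder and ToyConvTransformerStem are illustrative stand-ins (not repo code), and nn.Identity takes the place of a real positional embedding such as PositionalEncoding2D(hidden_dim=..., max_h=..., max_w=...).

    import torch
    from torch import nn, Tensor


    class ToyEncoder(nn.Module):
        """Stand-in for the convolutional backbone plus 1x1 projection."""

        def __init__(self, hidden_dim: int) -> None:
            super().__init__()
            self.proj = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=1)

        def forward(self, x: Tensor) -> Tensor:
            return self.proj(x)


    class ToyConvTransformerStem(nn.Module):
        """Encoder stem that receives its pixel positional embedding ready-built."""

        def __init__(self, encoder: nn.Module, pixel_pos_embedding: nn.Module) -> None:
            super().__init__()
            # Because the embedding arrives as a module instance, swapping a
            # fixed sinusoidal encoding for a learned one needs no change here.
            self.conv = nn.Sequential(
                encoder,
                pixel_pos_embedding,
                nn.Flatten(start_dim=2),
            )

        def forward(self, x: Tensor) -> Tensor:
            return self.conv(x)


    # nn.Identity is a placeholder for a real 2D positional embedding module.
    stem = ToyConvTransformerStem(ToyEncoder(hidden_dim=8), nn.Identity())
    out = stem(torch.randn(2, 1, 16, 32))  # -> [2, 8, 16 * 32]

One detail worth noting: the diff annotates the new parameters as Type[nn.Module], but both are used as instances (pixel_pos_embedding is placed directly inside the nn.Sequential, and self.token_pos_embedding is called on tensors in the forward path), so plain nn.Module would be the accurate annotation; the sketch above passes instances accordingly.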