diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 17
1 file changed, 7 insertions, 10 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 714792e..60c0ef8 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,6 +1,6 @@
 """Vision transformer for character recognition."""
 import math
-from typing import Tuple
+from typing import Tuple, Type
 
 from torch import nn, Tensor
 
@@ -24,6 +24,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: nn.Module,
         decoder: Decoder,
+        pixel_pos_embedding: Type[nn.Module],
+        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -43,11 +45,7 @@ class ConvTransformer(nn.Module):
                 out_channels=self.hidden_dim,
                 kernel_size=1,
             ),
-            PositionalEncoding2D(
-                hidden_dim=self.hidden_dim,
-                max_h=self.input_dims[1],
-                max_w=self.input_dims[2],
-            ),
+            pixel_pos_embedding,
             nn.Flatten(start_dim=2),
         )
 
@@ -57,9 +55,8 @@ class ConvTransformer(nn.Module):
         )
 
         # Positional encoding for decoder tokens.
-        self.token_pos_encoder = PositionalEncoding(
-            hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate
-        )
+        self.token_pos_embedding = token_pos_embedding
+
         # Head
         self.head = nn.Linear(
             in_features=self.hidden_dim, out_features=self.num_classes
@@ -119,7 +116,7 @@ class ConvTransformer(nn.Module):
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = self.token_pos_encoder(trg)
+        trg = self.token_pos_embedding(trg)
         out = self.decoder(x=trg, context=src, mask=trg_mask)
         logits = self.head(out)  # [B, Sy, T]
         logits = logits.permute(0, 2, 1)  # [B, T, Sy]
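This commit switches ConvTransformer from constructing its own PositionalEncoding2D and PositionalEncoding internally to receiving both positional embeddings as constructor arguments, so the caller decides which encoding to use. (Note that the new parameters are annotated as Type[nn.Module] but are used as instances: pixel_pos_embedding is placed directly into the nn.Sequential and token_pos_embedding is called in forward, so callers pass constructed modules.) Below is a minimal, self-contained sketch of this injection pattern; every class and value in it is a stand-in for illustration, not code from this repository. Only the pattern itself mirrors the diff.

# Sketch of the dependency-injection pattern adopted by this commit: the
# positional embedding is built by the caller and passed in as an nn.Module,
# instead of being hard-coded inside the network. All names are hypothetical.
import math

import torch
from torch import nn


class SinusoidalEncoding(nn.Module):
    """Stand-in for a token positional encoding: adds a fixed sinusoidal table."""

    def __init__(self, hidden_dim: int, max_len: int = 512) -> None:
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2) * (-math.log(10000.0) / hidden_dim)
        )
        pe = torch.zeros(max_len, hidden_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, S, hidden_dim]; add the encoding for the first S positions.
        return x + self.pe[: x.shape[1]]


class TinyDecoderNet(nn.Module):
    """Stand-in for the decoder path: the positional embedding is injected."""

    def __init__(
        self, num_classes: int, hidden_dim: int, token_pos_embedding: nn.Module
    ) -> None:
        super().__init__()
        self.token_embedding = nn.Embedding(num_classes, hidden_dim)
        self.token_pos_embedding = token_pos_embedding  # injected, as in the diff
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, trg: torch.Tensor) -> torch.Tensor:
        trg = self.token_pos_embedding(self.token_embedding(trg.long()))
        return self.head(trg)


# Usage: the caller constructs the embedding and hands it to the network.
net = TinyDecoderNet(
    num_classes=58,
    hidden_dim=64,
    token_pos_embedding=SinusoidalEncoding(hidden_dim=64),
)
logits = net(torch.randint(0, 58, (2, 16)))  # [B, S, num_classes]

The payoff of this refactoring is that swapping positional encodings (e.g. a learned embedding instead of a sinusoidal one) becomes a change at the construction site, typically a config entry, rather than an edit to ConvTransformer itself.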