From 3a21c29e2eff4378c63717f8920ca3ccbfef013c Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 3 Oct 2021 00:31:00 +0200
Subject: Lint files

---
 text_recognizer/networks/conv_transformer.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 714792e..60c0ef8 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,6 +1,6 @@
 """Vision transformer for character recognition."""
 import math
-from typing import Tuple
+from typing import Tuple, Type

 from torch import nn, Tensor

@@ -24,6 +24,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: nn.Module,
         decoder: Decoder,
+        pixel_pos_embedding: Type[nn.Module],
+        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -43,11 +45,7 @@ class ConvTransformer(nn.Module):
                 out_channels=self.hidden_dim,
                 kernel_size=1,
             ),
-            PositionalEncoding2D(
-                hidden_dim=self.hidden_dim,
-                max_h=self.input_dims[1],
-                max_w=self.input_dims[2],
-            ),
+            pixel_pos_embedding,
             nn.Flatten(start_dim=2),
         )

@@ -57,9 +55,8 @@ class ConvTransformer(nn.Module):
         )

         # Positional encoding for decoder tokens.
-        self.token_pos_encoder = PositionalEncoding(
-            hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate
-        )
+        self.token_pos_embedding = token_pos_embedding
+
         # Head
         self.head = nn.Linear(
             in_features=self.hidden_dim, out_features=self.num_classes
@@ -119,7 +116,7 @@ class ConvTransformer(nn.Module):
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = self.token_pos_encoder(trg)
+        trg = self.token_pos_embedding(trg)
         out = self.decoder(x=trg, context=src, mask=trg_mask)
         logits = self.head(out)  # [B, Sy, T]
         logits = logits.permute(0, 2, 1)  # [B, T, Sy]
--
cgit v1.2.3-70-g09d2
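
Context for the diff: although the commit is labelled "Lint files", it is effectively a small dependency-injection refactor. ConvTransformer no longer constructs its own PositionalEncoding2D and PositionalEncoding; both modules are now supplied through the constructor. Below is a minimal sketch of what a call site might look like after this change. The import paths for the positional encodings, the encoder/decoder stand-ins, and all hyperparameter values are assumptions; only the parameter names and the PositionalEncoding2D/PositionalEncoding keyword arguments are taken from the diff itself.

    import torch
    from torch import nn

    # Assumed import path; only conv_transformer.py's path appears in the diff.
    from text_recognizer.networks.conv_transformer import ConvTransformer
    # Assumed module path for the positional-encoding classes.
    from text_recognizer.networks.transformer.positional_encoding import (
        PositionalEncoding,
        PositionalEncoding2D,
    )

    hidden_dim = 128            # assumed value
    input_dims = (1, 576, 640)  # (C, H, W); assumed value
    dropout_rate = 0.1          # assumed value

    # Stand-ins so the sketch is self-contained; a real call site would pass a
    # CNN backbone and the repo's transformer Decoder here.
    encoder = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
    decoder = nn.Identity()

    network = ConvTransformer(
        input_dims=input_dims,
        hidden_dim=hidden_dim,
        dropout_rate=dropout_rate,
        num_classes=58,            # assumed value
        pad_index=torch.tensor(3), # annotated as Tensor in the diff
        encoder=encoder,
        decoder=decoder,
        # Previously built inside __init__ with exactly these arguments (see
        # the removed lines above); now the caller supplies both modules.
        pixel_pos_embedding=PositionalEncoding2D(
            hidden_dim=hidden_dim,
            max_h=input_dims[1],
            max_w=input_dims[2],
        ),
        token_pos_embedding=PositionalEncoding(
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
        ),
    )

Two caveats on this sketch. First, the leading constructor parameters (input_dims, hidden_dim, dropout_rate, num_classes) are inferred from the attributes the hunks reference; the full signature is not visible in this diff. Second, the new parameters are annotated Type[nn.Module], yet both are used as instances: pixel_pos_embedding is placed directly into an nn.Sequential, and token_pos_embedding is called on a tensor in forward. The sketch therefore passes constructed modules; a plain nn.Module annotation would describe the actual usage more accurately.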