From 6547760bfaae0a744ea633eb03cea89775e93c72 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 28 Sep 2022 23:41:39 +0200 Subject: Refactor additions --- text_recognizer/networks/conv_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index ddf3b2e..6d54918 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -80,7 +80,7 @@ class ConvTransformer(nn.Module): """ z = self.encoder(x) z = self.conv(z) - z += self.pixel_embedding(z) + z = z + self.pixel_embedding(z) z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] @@ -107,7 +107,7 @@ class ConvTransformer(nn.Module): trg = trg.long() trg_mask = trg != self.pad_index trg = self.token_embedding(trg) - trg += self.token_pos_embedding(trg) + trg = trg + self.token_pos_embedding(trg) out = self.decoder(x=trg, context=src, input_mask=trg_mask) logits = ( out @ torch.transpose(self.token_embedding.weight.to(trg.dtype), 0, 1) -- cgit v1.2.3-70-g09d2