diff options
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index ddf3b2e..6d54918 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -80,7 +80,7 @@ class ConvTransformer(nn.Module): """ z = self.encoder(x) z = self.conv(z) - z += self.pixel_embedding(z) + z = z + self.pixel_embedding(z) z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] @@ -107,7 +107,7 @@ class ConvTransformer(nn.Module): trg = trg.long() trg_mask = trg != self.pad_index trg = self.token_embedding(trg) - trg += self.token_pos_embedding(trg) + trg = trg + self.token_pos_embedding(trg) out = self.decoder(x=trg, context=src, input_mask=trg_mask) logits = ( out @ torch.transpose(self.token_embedding.weight.to(trg.dtype), 0, 1) |