diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-28 23:41:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-28 23:41:39 +0200 |
commit | 6547760bfaae0a744ea633eb03cea89775e93c72 (patch) | |
tree | 75f908f4001be0e3c80e9cdb7a7de3d553246986 /text_recognizer/networks/conv_transformer.py | |
parent | df0c59742803ec4a06afe981467d9f63758714a6 (diff) |
Refactor additions
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index ddf3b2e..6d54918 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -80,7 +80,7 @@ class ConvTransformer(nn.Module): """ z = self.encoder(x) z = self.conv(z) - z += self.pixel_embedding(z) + z = z + self.pixel_embedding(z) z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] @@ -107,7 +107,7 @@ class ConvTransformer(nn.Module): trg = trg.long() trg_mask = trg != self.pad_index trg = self.token_embedding(trg) - trg += self.token_pos_embedding(trg) + trg = trg + self.token_pos_embedding(trg) out = self.decoder(x=trg, context=src, input_mask=trg_mask) logits = ( out @ torch.transpose(self.token_embedding.weight.to(trg.dtype), 0, 1) |