summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- text_recognizer/networks/conv_transformer.py 4
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index ddf3b2e..6d54918 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -80,7 +80,7 @@ class ConvTransformer(nn.Module):
"""
z = self.encoder(x)
z = self.conv(z)
- z += self.pixel_embedding(z)
+ z = z + self.pixel_embedding(z)
z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
@@ -107,7 +107,7 @@ class ConvTransformer(nn.Module):
trg = trg.long()
trg_mask = trg != self.pad_index
trg = self.token_embedding(trg)
- trg += self.token_pos_embedding(trg)
+ trg = trg + self.token_pos_embedding(trg)
out = self.decoder(x=trg, context=src, input_mask=trg_mask)
logits = (
out @ torch.transpose(self.token_embedding.weight.to(trg.dtype), 0, 1)