summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-28 23:41:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-28 23:41:39 +0200
commit6547760bfaae0a744ea633eb03cea89775e93c72 (patch)
tree75f908f4001be0e3c80e9cdb7a7de3d553246986
parentdf0c59742803ec4a06afe981467d9f63758714a6 (diff)
Refactor additions
-rw-r--r--text_recognizer/networks/conv_transformer.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index ddf3b2e..6d54918 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -80,7 +80,7 @@ class ConvTransformer(nn.Module):
"""
z = self.encoder(x)
z = self.conv(z)
- z += self.pixel_embedding(z)
+ z = z + self.pixel_embedding(z)
z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
@@ -107,7 +107,7 @@ class ConvTransformer(nn.Module):
trg = trg.long()
trg_mask = trg != self.pad_index
trg = self.token_embedding(trg)
- trg += self.token_pos_embedding(trg)
+ trg = trg + self.token_pos_embedding(trg)
out = self.decoder(x=trg, context=src, input_mask=trg_mask)
logits = (
out @ torch.transpose(self.token_embedding.weight.to(trg.dtype), 0, 1)