summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/cnn_transformer.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 23:35:42 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 23:35:42 +0100
commitd691b548cd0b6fc4ea184d64261f633789fee021 (patch)
tree99e2fc5481ce102d5655b65681274e5f0286306f /src/text_recognizer/networks/cnn_transformer.py
parentff9a21d333f11a42e67c1963ed67de9c0fda87c9 (diff)
working on vq-vae
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py12
1 files changed, 0 insertions, 12 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index caa73e3..43e5403 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -109,18 +109,6 @@ class CNNTransformer(nn.Module):
b, t, _ = src.shape
- # Insert sos and eos token.
- # sos_token = self.character_embedding(
- # torch.Tensor([self.vocab_size - 2]).long().to(src.device)
- # )
- # eos_token = self.character_embedding(
- # torch.Tensor([self.vocab_size - 1]).long().to(src.device)
- # )
-
- # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
- # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
- # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
- # src = torch.cat((sos_tokens, src), dim=1)
src += self.src_position_embedding[:, :t]
return src