diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 23:35:42 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 23:35:42 +0100 |
commit | d691b548cd0b6fc4ea184d64261f633789fee021 (patch) | |
tree | 99e2fc5481ce102d5655b65681274e5f0286306f /src/text_recognizer/networks/cnn_transformer.py | |
parent | ff9a21d333f11a42e67c1963ed67de9c0fda87c9 (diff) |
working on vq-vae
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 12 |
1 files changed, 0 insertions, 12 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index caa73e3..43e5403 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -109,18 +109,6 @@ class CNNTransformer(nn.Module): b, t, _ = src.shape - # Insert sos and eos token. - # sos_token = self.character_embedding( - # torch.Tensor([self.vocab_size - 2]).long().to(src.device) - # ) - # eos_token = self.character_embedding( - # torch.Tensor([self.vocab_size - 1]).long().to(src.device) - # ) - - # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1) - # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1) - # src = torch.cat((sos_tokens, src, eos_tokens), dim=1) - # src = torch.cat((sos_tokens, src), dim=1) src += self.src_position_embedding[:, :t] return src |