From d691b548cd0b6fc4ea184d64261f633789fee021 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 7 Jan 2021 23:35:42 +0100 Subject: working on vq-vae --- src/text_recognizer/networks/cnn_transformer.py | 12 ------------ 1 file changed, 12 deletions(-) (limited to 'src/text_recognizer/networks/cnn_transformer.py') diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index caa73e3..43e5403 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -109,18 +109,6 @@ class CNNTransformer(nn.Module): b, t, _ = src.shape - # Insert sos and eos token. - # sos_token = self.character_embedding( - # torch.Tensor([self.vocab_size - 2]).long().to(src.device) - # ) - # eos_token = self.character_embedding( - # torch.Tensor([self.vocab_size - 1]).long().to(src.device) - # ) - - # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1) - # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1) - # src = torch.cat((sos_tokens, src, eos_tokens), dim=1) - # src = torch.cat((sos_tokens, src), dim=1) src += self.src_position_embedding[:, :t] return src -- cgit v1.2.3-70-g09d2