summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/cnn_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py12
1 files changed, 0 insertions, 12 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index caa73e3..43e5403 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -109,18 +109,6 @@ class CNNTransformer(nn.Module):
b, t, _ = src.shape
- # Insert sos and eos token.
- # sos_token = self.character_embedding(
- # torch.Tensor([self.vocab_size - 2]).long().to(src.device)
- # )
- # eos_token = self.character_embedding(
- # torch.Tensor([self.vocab_size - 1]).long().to(src.device)
- # )
-
- # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
- # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
- # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
- # src = torch.cat((sos_tokens, src), dim=1)
src += self.src_position_embedding[:, :t]
return src