summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/text_decoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 01:12:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 01:12:13 +0200
commitc614c472707910658b86bb28b9f02062e6982999 (patch)
treebd043a8196f9ee3e5339ec7be17116c0ba0cc1ef /text_recognizer/networks/text_decoder.py
parent03029695897fff72c9e7a66a3f986877ebb0b0ff (diff)
Make rotary pos encoding mandatory
Diffstat (limited to 'text_recognizer/networks/text_decoder.py')
-rw-r--r--text_recognizer/networks/text_decoder.py5
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py
index c054b41..7ee6720 100644
--- a/text_recognizer/networks/text_decoder.py
+++ b/text_recognizer/networks/text_decoder.py
@@ -1,5 +1,5 @@
"""Text decoder."""
-from typing import Type
+from typing import Optional, Type
import torch
from torch import Tensor, nn
@@ -16,7 +16,6 @@ class TextDecoder(nn.Module):
num_classes: int,
pad_index: Tensor,
decoder: Decoder,
- token_pos_embedding: Type[nn.Module],
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
@@ -26,7 +25,6 @@ class TextDecoder(nn.Module):
self.token_embedding = nn.Embedding(
num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
)
- self.token_pos_embedding = token_pos_embedding
self.to_logits = nn.Linear(
in_features=self.hidden_dim, out_features=self.num_classes
)
@@ -52,7 +50,6 @@ class TextDecoder(nn.Module):
tokens = tokens.long()
mask = tokens != self.pad_index
tokens = self.token_embedding(tokens)
- tokens = tokens + self.token_pos_embedding(tokens)
tokens = self.decoder(x=tokens, context=img_features, mask=mask)
logits = (
tokens @ torch.transpose(self.token_embedding.weight.to(tokens.dtype), 0, 1)