diff options
Diffstat (limited to 'text_recognizer/networks/text_decoder.py')
-rw-r--r-- | text_recognizer/networks/text_decoder.py | 5 |
1 files changed, 1 insertions, 4 deletions
diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py index c054b41..7ee6720 100644 --- a/text_recognizer/networks/text_decoder.py +++ b/text_recognizer/networks/text_decoder.py @@ -1,5 +1,5 @@ """Text decoder.""" -from typing import Type +from typing import Optional, Type import torch from torch import Tensor, nn @@ -16,7 +16,6 @@ class TextDecoder(nn.Module): num_classes: int, pad_index: Tensor, decoder: Decoder, - token_pos_embedding: Type[nn.Module], ) -> None: super().__init__() self.hidden_dim = hidden_dim @@ -26,7 +25,6 @@ class TextDecoder(nn.Module): self.token_embedding = nn.Embedding( num_embeddings=self.num_classes, embedding_dim=self.hidden_dim ) - self.token_pos_embedding = token_pos_embedding self.to_logits = nn.Linear( in_features=self.hidden_dim, out_features=self.num_classes ) @@ -52,7 +50,6 @@ class TextDecoder(nn.Module): tokens = tokens.long() mask = tokens != self.pad_index tokens = self.token_embedding(tokens) - tokens = tokens + self.token_pos_embedding(tokens) tokens = self.decoder(x=tokens, context=img_features, mask=mask) logits = ( tokens @ torch.transpose(self.token_embedding.weight.to(tokens.dtype), 0, 1) |