diff options
-rw-r--r-- | text_recognizer/network/vit.py | 11 |
1 file changed, 6 insertions, 5 deletions
diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/vit.py
index 80176a8..b6203d7 100644
--- a/text_recognizer/network/vit.py
+++ b/text_recognizer/network/vit.py
@@ -57,12 +57,13 @@ class VisionTransformer(nn.Module):
         x += self.patch_embedding.to(img.device, dtype=img.dtype)
         return self.encoder(x)
 
-    def decode(self, text: Tensor, context: Tensor) -> Tensor:
+    def decode(self, text: Tensor, img_features: Tensor) -> Tensor:
         text = text.long()
+        # TODO: add mask to decoder
         mask = text != self.pad_index
         tokens = self.token_embedding(text)
         tokens = tokens + self.pos_embedding(tokens)
-        output = self.decoder(tokens, context)
+        output = self.decoder(tokens, context=img_features)
         return self.to_logits(output)
 
     def forward(
@@ -71,6 +72,6 @@ class VisionTransformer(nn.Module):
         text: Tensor,
     ) -> Tensor:
         """Applies decoder block on input signals."""
-        context = self.encode(img)
-        logits = self.decode(text, context)
-        return logits.permute(0, 2, 1)
+        img_features = self.encode(img)
+        logits = self.decode(text, img_features)
+        return logits # [B, N, C]