diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:20 +0200 |
commit | 617bf7f0285090b85817a398ef4bb871d4f616e9 (patch) | |
tree | 57ea13d9f0bf8cf9fc535126338cbe725b8a89bf /text_recognizer/network | |
parent | 27b001503f068a89acc40cc960a8b54feb1bddc3 (diff) |
Rename context
Diffstat (limited to 'text_recognizer/network')
-rw-r--r-- | text_recognizer/network/vit.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/vit.py
index 80176a8..b6203d7 100644
--- a/text_recognizer/network/vit.py
+++ b/text_recognizer/network/vit.py
@@ -57,12 +57,13 @@ class VisionTransformer(nn.Module):
         x += self.patch_embedding.to(img.device, dtype=img.dtype)
         return self.encoder(x)
 
-    def decode(self, text: Tensor, context: Tensor) -> Tensor:
+    def decode(self, text: Tensor, img_features: Tensor) -> Tensor:
         text = text.long()
+        # TODO: add mask to decoder
         mask = text != self.pad_index
         tokens = self.token_embedding(text)
         tokens = tokens + self.pos_embedding(tokens)
-        output = self.decoder(tokens, context)
+        output = self.decoder(tokens, context=img_features)
         return self.to_logits(output)
 
     def forward(
@@ -71,6 +72,6 @@ class VisionTransformer(nn.Module):
         text: Tensor,
     ) -> Tensor:
         """Applies decoder block on input signals."""
-        context = self.encode(img)
-        logits = self.decode(text, context)
-        return logits.permute(0, 2, 1)
+        img_features = self.encode(img)
+        logits = self.decode(text, img_features)
+        return logits  # [B, N, C]