From 617bf7f0285090b85817a398ef4bb871d4f616e9 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sat, 2 Sep 2023 01:53:20 +0200
Subject: Rename context

---
 text_recognizer/network/vit.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

(limited to 'text_recognizer/network')

diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/vit.py
index 80176a8..b6203d7 100644
--- a/text_recognizer/network/vit.py
+++ b/text_recognizer/network/vit.py
@@ -57,12 +57,13 @@ class VisionTransformer(nn.Module):
         x += self.patch_embedding.to(img.device, dtype=img.dtype)
         return self.encoder(x)
 
-    def decode(self, text: Tensor, context: Tensor) -> Tensor:
+    def decode(self, text: Tensor, img_features: Tensor) -> Tensor:
         text = text.long()
+        # TODO: add mask to decoder
         mask = text != self.pad_index
         tokens = self.token_embedding(text)
         tokens = tokens + self.pos_embedding(tokens)
-        output = self.decoder(tokens, context)
+        output = self.decoder(tokens, context=img_features)
         return self.to_logits(output)
 
     def forward(
@@ -71,6 +72,6 @@ class VisionTransformer(nn.Module):
         text: Tensor,
     ) -> Tensor:
         """Applies decoder block on input signals."""
-        context = self.encode(img)
-        logits = self.decode(text, context)
-        return logits.permute(0, 2, 1)
+        img_features = self.encode(img)
+        logits = self.decode(text, img_features)
+        return logits  # [B, N, C]
-- 
cgit v1.2.3-70-g09d2