summaryrefslogtreecommitdiff
path: root/text_recognizer/network
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/network')
-rw-r--r--text_recognizer/network/vit.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/vit.py
index 80176a8..b6203d7 100644
--- a/text_recognizer/network/vit.py
+++ b/text_recognizer/network/vit.py
@@ -57,12 +57,13 @@ class VisionTransformer(nn.Module):
x += self.patch_embedding.to(img.device, dtype=img.dtype)
return self.encoder(x)
- def decode(self, text: Tensor, context: Tensor) -> Tensor:
+ def decode(self, text: Tensor, img_features: Tensor) -> Tensor:
text = text.long()
+ # TODO: add mask to decoder
mask = text != self.pad_index
tokens = self.token_embedding(text)
tokens = tokens + self.pos_embedding(tokens)
- output = self.decoder(tokens, context)
+ output = self.decoder(tokens, context=img_features)
return self.to_logits(output)
def forward(
@@ -71,6 +72,6 @@ class VisionTransformer(nn.Module):
text: Tensor,
) -> Tensor:
"""Applies decoder block on input signals."""
- context = self.encode(img)
- logits = self.decode(text, context)
- return logits.permute(0, 2, 1)
+ img_features = self.encode(img)
+ logits = self.decode(text, img_features)
+ return logits # [B, N, C]