summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/text_decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/text_decoder.py')
-rw-r--r--text_recognizer/networks/text_decoder.py55
1 files changed, 0 insertions, 55 deletions
diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py
deleted file mode 100644
index 500bcf9..0000000
--- a/text_recognizer/networks/text_decoder.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""Text decoder."""
-import torch
-from torch import Tensor, nn
-
-from text_recognizer.networks.transformer.decoder import Decoder
-
-
-class TextDecoder(nn.Module):
- """Decodes images to token logits."""
-
- def __init__(
- self,
- dim: int,
- num_classes: int,
- pad_index: Tensor,
- decoder: Decoder,
- ) -> None:
- super().__init__()
- self.dim = dim
- self.num_classes = num_classes
- self.pad_index = pad_index
- self.decoder = decoder
- self.token_embedding = nn.Embedding(
- num_embeddings=self.num_classes, embedding_dim=self.dim
- )
- self.to_logits = nn.Linear(in_features=self.dim, out_features=self.num_classes)
-
- def forward(self, tokens: Tensor, img_features: Tensor) -> Tensor:
- """Decodes latent images embedding into logit tokens.
-
- Args:
- tokens (Tensor): Token indecies.
- img_features (Tensor): Latent images embedding.
-
- Shapes:
- - tokens: :math: `(B, Sy)`
- - img_features: :math: `(B, Sx, D)`
- - logits: :math: `(B, Sy, C)`
-
- where Sy is the length of the output, C is the number of classes
- and D is the hidden dimension.
-
- Returns:
- Tensor: Sequence of logits.
- """
- tokens = tokens.long()
- mask = tokens != self.pad_index
- tokens = self.token_embedding(tokens)
- tokens = self.decoder(x=tokens, context=img_features, mask=mask)
- logits = (
- tokens @ torch.transpose(self.token_embedding.weight.to(tokens.dtype), 0, 1)
- ).float()
- logits = self.to_logits(tokens) # [B, Sy, C]
- logits = logits.permute(0, 2, 1) # [B, C, Sy]
- return logits