summary refs log tree commit diff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- text_recognizer/networks/conv_transformer.py 15
1 files changed, 9 insertions, 6 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index f07b97d..ff98ec6 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -57,8 +57,10 @@ class ConvTransformer(nn.Module):
self.token_pos_embedding = None
log.debug("Decoder already have a positional embedding.")
- # Head
- self.head = nn.Linear(
+ self.norm = nn.LayerNorm(self.hidden_dim)
+
+ # Output layer
+ self.to_logits = nn.Linear(
in_features=self.hidden_dim, out_features=self.num_classes
)
@@ -66,11 +68,11 @@ class ConvTransformer(nn.Module):
self.init_weights()
def init_weights(self) -> None:
- """Initalize weights for decoder network and head."""
+ """Initialize weights for decoder network and to_logits."""
bound = 0.1
self.token_embedding.weight.data.uniform_(-bound, bound)
- self.head.bias.data.zero_()
- self.head.weight.data.uniform_(-bound, bound)
+ self.to_logits.bias.data.zero_()
+ self.to_logits.weight.data.uniform_(-bound, bound)
def encode(self, x: Tensor) -> Tensor:
"""Encodes an image into a latent feature vector.
@@ -125,7 +127,8 @@ class ConvTransformer(nn.Module):
else trg
)
out = self.decoder(x=trg, context=src, input_mask=trg_mask)
- logits = self.head(out) # [B, Sy, T]
+ out = self.norm(out)
+ logits = self.to_logits(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits