diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-28 21:37:15 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-28 21:37:15 +0100 |
commit | 662f3c6559e23914309f63472035355f4098091c (patch) | |
tree | d4586b847e05d6677a3ab0b284e6deb7b2090ccb /text_recognizer | |
parent | 82a8efc3ba5dd2048b3b46e59c2da0face44fed1 (diff) |
Add norm layer to output from decoder
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 15 |
1 file changed, 9 insertions, 6 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index f07b97d..ff98ec6 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -57,8 +57,10 @@ class ConvTransformer(nn.Module): self.token_pos_embedding = None log.debug("Decoder already have a positional embedding.") - # Head - self.head = nn.Linear( + self.norm = nn.LayerNorm(self.hidden_dim) + + # Output layer + self.to_logits = nn.Linear( in_features=self.hidden_dim, out_features=self.num_classes ) @@ -66,11 +68,11 @@ class ConvTransformer(nn.Module): self.init_weights() def init_weights(self) -> None: - """Initalize weights for decoder network and head.""" + """Initalize weights for decoder network and to_logits.""" bound = 0.1 self.token_embedding.weight.data.uniform_(-bound, bound) - self.head.bias.data.zero_() - self.head.weight.data.uniform_(-bound, bound) + self.to_logits.bias.data.zero_() + self.to_logits.weight.data.uniform_(-bound, bound) def encode(self, x: Tensor) -> Tensor: """Encodes an image into a latent feature vector. @@ -125,7 +127,8 @@ class ConvTransformer(nn.Module): else trg ) out = self.decoder(x=trg, context=src, input_mask=trg_mask) - logits = self.head(out) # [B, Sy, T] + out = self.norm(out) + logits = self.to_logits(out) # [B, Sy, T] logits = logits.permute(0, 2, 1) # [B, T, Sy] return logits |