diff options
Diffstat (limited to 'text_recognizer')
| -rw-r--r-- | text_recognizer/networks/conv_transformer.py | 15 | 
1 files changed, 9 insertions, 6 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index f07b97d..ff98ec6 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -57,8 +57,10 @@ class ConvTransformer(nn.Module):              self.token_pos_embedding = None              log.debug("Decoder already have a positional embedding.") -        # Head -        self.head = nn.Linear( +        self.norm = nn.LayerNorm(self.hidden_dim) + +        # Output layer +        self.to_logits = nn.Linear(              in_features=self.hidden_dim, out_features=self.num_classes          ) @@ -66,11 +68,11 @@ class ConvTransformer(nn.Module):          self.init_weights()      def init_weights(self) -> None: -        """Initalize weights for decoder network and head.""" +        """Initalize weights for decoder network and to_logits."""          bound = 0.1          self.token_embedding.weight.data.uniform_(-bound, bound) -        self.head.bias.data.zero_() -        self.head.weight.data.uniform_(-bound, bound) +        self.to_logits.bias.data.zero_() +        self.to_logits.weight.data.uniform_(-bound, bound)      def encode(self, x: Tensor) -> Tensor:          """Encodes an image into a latent feature vector. @@ -125,7 +127,8 @@ class ConvTransformer(nn.Module):              else trg          )          out = self.decoder(x=trg, context=src, input_mask=trg_mask) -        logits = self.head(out)  # [B, Sy, T] +        out = self.norm(out) +        logits = self.to_logits(out)  # [B, Sy, T]          logits = logits.permute(0, 2, 1)  # [B, T, Sy]          return logits  |