diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index d66643b..365906f 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,7 +1,6 @@ """Base network module.""" from typing import Optional, Tuple, Type -from loguru import logger as log from torch import nn, Tensor from text_recognizer.networks.transformer.decoder import Decoder @@ -42,7 +41,6 @@ class ConvTransformer(nn.Module): self.token_pos_embedding = token_pos_embedding else: self.token_pos_embedding = None - log.debug("Decoder already have a positional embedding.") self.pixel_embedding = pixel_embedding @@ -64,7 +62,7 @@ class ConvTransformer(nn.Module): def init_weights(self) -> None: """Initalize weights for decoder network and to_logits.""" - nn.init.kaiming_normal_(self.token_emb.emb.weight) + nn.init.kaiming_normal_(self.token_embedding.weight) def encode(self, x: Tensor) -> Tensor: """Encodes an image into a latent feature vector. @@ -85,7 +83,7 @@ class ConvTransformer(nn.Module): """ z = self.encoder(x) z = self.conv(z) - z = self.pixel_pos_embedding(z) + z = self.pixel_embedding(z) z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] |