diff options
Diffstat (limited to 'text_recognizer/networks')
| -rw-r--r-- | text_recognizer/networks/conv_transformer.py | 6 | 
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index d66643b..365906f 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,7 +1,6 @@  """Base network module."""  from typing import Optional, Tuple, Type -from loguru import logger as log  from torch import nn, Tensor  from text_recognizer.networks.transformer.decoder import Decoder @@ -42,7 +41,6 @@ class ConvTransformer(nn.Module):              self.token_pos_embedding = token_pos_embedding          else:              self.token_pos_embedding = None -            log.debug("Decoder already have a positional embedding.")          self.pixel_embedding = pixel_embedding @@ -64,7 +62,7 @@ class ConvTransformer(nn.Module):      def init_weights(self) -> None:          """Initalize weights for decoder network and to_logits.""" -        nn.init.kaiming_normal_(self.token_emb.emb.weight) +        nn.init.kaiming_normal_(self.token_embedding.weight)      def encode(self, x: Tensor) -> Tensor:          """Encodes an image into a latent feature vector. @@ -85,7 +83,7 @@ class ConvTransformer(nn.Module):          """          z = self.encoder(x)          z = self.conv(z) -        z = self.pixel_pos_embedding(z) +        z = self.pixel_embedding(z)          z = z.flatten(start_dim=2)          # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]  |