From 7a8c9d40241868f3ecc5c8f5e4d0c06caa6c9e96 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 10 Jun 2022 00:33:27 +0200 Subject: Fix typo in conv transformer --- text_recognizer/networks/conv_transformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index d66643b..365906f 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,7 +1,6 @@ """Base network module.""" from typing import Optional, Tuple, Type -from loguru import logger as log from torch import nn, Tensor from text_recognizer.networks.transformer.decoder import Decoder @@ -42,7 +41,6 @@ class ConvTransformer(nn.Module): self.token_pos_embedding = token_pos_embedding else: self.token_pos_embedding = None - log.debug("Decoder already have a positional embedding.") self.pixel_embedding = pixel_embedding @@ -64,7 +62,7 @@ class ConvTransformer(nn.Module): def init_weights(self) -> None: """Initalize weights for decoder network and to_logits.""" - nn.init.kaiming_normal_(self.token_emb.emb.weight) + nn.init.kaiming_normal_(self.token_embedding.weight) def encode(self, x: Tensor) -> Tensor: """Encodes an image into a latent feature vector. @@ -85,7 +83,7 @@ class ConvTransformer(nn.Module): """ z = self.encoder(x) z = self.conv(z) - z = self.pixel_pos_embedding(z) + z = self.pixel_embedding(z) z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] -- cgit v1.2.3-70-g09d2