summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/conv_transformer.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index d66643b..365906f 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,7 +1,6 @@
"""Base network module."""
from typing import Optional, Tuple, Type
-from loguru import logger as log
from torch import nn, Tensor
from text_recognizer.networks.transformer.decoder import Decoder
@@ -42,7 +41,6 @@ class ConvTransformer(nn.Module):
self.token_pos_embedding = token_pos_embedding
else:
self.token_pos_embedding = None
- log.debug("Decoder already have a positional embedding.")
self.pixel_embedding = pixel_embedding
@@ -64,7 +62,7 @@ class ConvTransformer(nn.Module):
def init_weights(self) -> None:
"""Initalize weights for decoder network and to_logits."""
- nn.init.kaiming_normal_(self.token_emb.emb.weight)
+ nn.init.kaiming_normal_(self.token_embedding.weight)
def encode(self, x: Tensor) -> Tensor:
"""Encodes an image into a latent feature vector.
@@ -85,7 +83,7 @@ class ConvTransformer(nn.Module):
"""
z = self.encoder(x)
z = self.conv(z)
- z = self.pixel_pos_embedding(z)
+ z = self.pixel_embedding(z)
z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]