summary refs log tree commit diff
path: root/text_recognizer
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-06-10 00:33:27 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-06-10 00:33:27 +0200
commit 7a8c9d40241868f3ecc5c8f5e4d0c06caa6c9e96 (patch)
tree 7ae50b4b7e3433e1e8e946edccab45fffc49c501 /text_recognizer
parent 181d0e71189c710c374024c7198094a7dfe86044 (diff)
Fix typo in conv transformer
Diffstat (limited to 'text_recognizer')
-rw-r--r-- text_recognizer/networks/conv_transformer.py 6
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index d66643b..365906f 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,7 +1,6 @@
"""Base network module."""
from typing import Optional, Tuple, Type
-from loguru import logger as log
from torch import nn, Tensor
from text_recognizer.networks.transformer.decoder import Decoder
@@ -42,7 +41,6 @@ class ConvTransformer(nn.Module):
self.token_pos_embedding = token_pos_embedding
else:
self.token_pos_embedding = None
- log.debug("Decoder already have a positional embedding.")
self.pixel_embedding = pixel_embedding
@@ -64,7 +62,7 @@ class ConvTransformer(nn.Module):
def init_weights(self) -> None:
"""Initalize weights for decoder network and to_logits."""
- nn.init.kaiming_normal_(self.token_emb.emb.weight)
+ nn.init.kaiming_normal_(self.token_embedding.weight)
def encode(self, x: Tensor) -> Tensor:
"""Encodes an image into a latent feature vector.
@@ -85,7 +83,7 @@ class ConvTransformer(nn.Module):
"""
z = self.encoder(x)
z = self.conv(z)
- z = self.pixel_pos_embedding(z)
+ z = self.pixel_embedding(z)
z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]