summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/base.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
index f6f1831..29c3bbc 100644
--- a/text_recognizer/networks/base.py
+++ b/text_recognizer/networks/base.py
@@ -5,10 +5,12 @@ from typing import Optional, Tuple, Type
from loguru import logger as log
from torch import nn, Tensor
-from text_recognizer.networks.transformer.layers import Decoder
+from text_recognizer.networks.transformer.decoder import Decoder
class BaseTransformer(nn.Module):
+ """Base transformer network."""
+
def __init__(
self,
input_dims: Tuple[int, int, int],
@@ -39,8 +41,6 @@ class BaseTransformer(nn.Module):
self.token_pos_embedding = None
log.debug("Decoder already have a positional embedding.")
- self.norm = nn.LayerNorm(self.hidden_dim)
-
# Output layer
self.to_logits = nn.Linear(
in_features=self.hidden_dim, out_features=self.num_classes
@@ -76,7 +76,6 @@ class BaseTransformer(nn.Module):
else trg
)
out = self.decoder(x=trg, context=src, input_mask=trg_mask)
- out = self.norm(out)
logits = self.to_logits(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits