diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-01-29 15:56:26 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-01-29 15:56:26 +0100 |
commit | eae5ca3e561eff0a8adfe4f27f13ccc49691e468 (patch) | |
tree | 18ebf1eac18d4653fc2aae763415fa0c1b6da0ae /text_recognizer | |
parent | c2dd53291e34f2ca75c8dbcd9b0653899682fae4 (diff) |
feat(base): remove output norm
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/base.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py index f6f1831..29c3bbc 100644 --- a/text_recognizer/networks/base.py +++ b/text_recognizer/networks/base.py @@ -5,10 +5,12 @@ from typing import Optional, Tuple, Type from loguru import logger as log from torch import nn, Tensor -from text_recognizer.networks.transformer.layers import Decoder +from text_recognizer.networks.transformer.decoder import Decoder class BaseTransformer(nn.Module): + """Base transformer network.""" + def __init__( self, input_dims: Tuple[int, int, int], @@ -39,8 +41,6 @@ class BaseTransformer(nn.Module): self.token_pos_embedding = None log.debug("Decoder already have a positional embedding.") - self.norm = nn.LayerNorm(self.hidden_dim) - # Output layer self.to_logits = nn.Linear( in_features=self.hidden_dim, out_features=self.num_classes @@ -76,7 +76,6 @@ class BaseTransformer(nn.Module): else trg ) out = self.decoder(x=trg, context=src, input_mask=trg_mask) - out = self.norm(out) logits = self.to_logits(out) # [B, Sy, T] logits = logits.permute(0, 2, 1) # [B, T, Sy] return logits |