diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 01:12:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 01:12:13 +0200 |
commit | c614c472707910658b86bb28b9f02062e6982999 (patch) | |
tree | bd043a8196f9ee3e5339ec7be17116c0ba0cc1ef /text_recognizer/networks/transformer/decoder.py | |
parent | 03029695897fff72c9e7a66a3f986877ebb0b0ff (diff) |
Make rotary pos encoding mandatory
Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r-- | text_recognizer/networks/transformer/decoder.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py index 741f5b3..09d2dce 100644 --- a/text_recognizer/networks/transformer/decoder.py +++ b/text_recognizer/networks/transformer/decoder.py @@ -1,13 +1,11 @@ """Transformer decoder module.""" from copy import deepcopy -from typing import Optional, Type +from typing import Optional from torch import Tensor, nn -from text_recognizer.networks.transformer.attention import Attention from text_recognizer.networks.transformer.decoder_block import DecoderBlock from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding -from text_recognizer.networks.transformer.ff import FeedForward class Decoder(nn.Module): @@ -18,7 +16,7 @@ class Decoder(nn.Module): depth: int, dim: int, block: DecoderBlock, - rotary_embedding: Optional[RotaryEmbedding] = None, + rotary_embedding: RotaryEmbedding, ) -> None: super().__init__() self.depth = depth |