summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/decoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 01:12:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 01:12:13 +0200
commitc614c472707910658b86bb28b9f02062e6982999 (patch)
treebd043a8196f9ee3e5339ec7be17116c0ba0cc1ef /text_recognizer/networks/transformer/decoder.py
parent03029695897fff72c9e7a66a3f986877ebb0b0ff (diff)
Make rotary pos encoding mandatory
Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r--text_recognizer/networks/transformer/decoder.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index 741f5b3..09d2dce 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -1,13 +1,11 @@
"""Transformer decoder module."""
from copy import deepcopy
-from typing import Optional, Type
+from typing import Optional
from torch import Tensor, nn
-from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.decoder_block import DecoderBlock
from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
-from text_recognizer.networks.transformer.ff import FeedForward
class Decoder(nn.Module):
@@ -18,7 +16,7 @@ class Decoder(nn.Module):
depth: int,
dim: int,
block: DecoderBlock,
- rotary_embedding: Optional[RotaryEmbedding] = None,
+ rotary_embedding: RotaryEmbedding,
) -> None:
super().__init__()
self.depth = depth