summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/transformer/decoder.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index 1812e40..db5c766 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -20,6 +20,7 @@ class DecoderBlock(nn.Module):
) -> None:
super().__init__()
self.layers = ("self_attn", "cross_attn", "ff")
+ self.has_pos_emb = self_attn.rotary_embedding is not None
self.blocks = self._build(self_attn, norm, ff, cross_attn)
def _build(
@@ -81,10 +82,10 @@ class DecoderBlock(nn.Module):
class Decoder(nn.Module):
"""Decoder Network."""
- def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None:
+ def __init__(self, depth: int, block: DecoderBlock) -> None:
super().__init__()
self.depth = depth
- self.has_pos_emb = has_pos_emb
+ self.has_pos_emb = block.has_pos_emb
self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
def forward(