summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/decoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-10 00:34:20 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-10 00:34:20 +0200
commit7eb0002f599367a5b9a80374c89e08d7a93d6a1b (patch)
tree7af788a84cd387c42c26d74d4c273193c99151ed /text_recognizer/networks/transformer/decoder.py
parente193ca9d94456c1933e25d2f8b7e8224d3e92ae3 (diff)
Fix check for pos emb
Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r--text_recognizer/networks/transformer/decoder.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index 1812e40..db5c766 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -20,6 +20,7 @@ class DecoderBlock(nn.Module):
) -> None:
super().__init__()
self.layers = ("self_attn", "cross_attn", "ff")
+ self.has_pos_emb = self_attn.rotary_embedding is not None
self.blocks = self._build(self_attn, norm, ff, cross_attn)
def _build(
@@ -81,10 +82,10 @@ class DecoderBlock(nn.Module):
class Decoder(nn.Module):
"""Decoder Network."""
- def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None:
+ def __init__(self, depth: int, block: DecoderBlock) -> None:
super().__init__()
self.depth = depth
- self.has_pos_emb = has_pos_emb
+ self.has_pos_emb = block.has_pos_emb
self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
def forward(