From 7eb0002f599367a5b9a80374c89e08d7a93d6a1b Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 10 Jun 2022 00:34:20 +0200
Subject: Fix check for pos emb

---
 text_recognizer/networks/transformer/decoder.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index 1812e40..db5c766 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -20,6 +20,7 @@ class DecoderBlock(nn.Module):
     ) -> None:
         super().__init__()
         self.layers = ("self_attn", "cross_attn", "ff")
+        self.has_pos_emb = self_attn.rotary_embedding is not None
         self.blocks = self._build(self_attn, norm, ff, cross_attn)

     def _build(
@@ -81,10 +82,10 @@ class DecoderBlock(nn.Module):
 class Decoder(nn.Module):
     """Decoder Network."""

-    def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None:
+    def __init__(self, depth: int, block: DecoderBlock) -> None:
         super().__init__()
         self.depth = depth
-        self.has_pos_emb = has_pos_emb
+        self.has_pos_emb = block.has_pos_emb
         self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])

     def forward(
--
cgit v1.2.3-70-g09d2
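
Note (not part of the patch): the change moves the positional-embedding flag from a constructor argument of Decoder to a property derived from the self-attention layer of the block. The sketch below illustrates the resulting call-site behaviour with minimal stand-in classes; the real Attention, DecoderBlock, and Decoder in text_recognizer.networks.transformer have richer signatures, so treat every name and parameter here as a hypothetical simplification.

    # Minimal sketch of the patched wiring, assuming only that the
    # self-attention layer exposes a `rotary_embedding` attribute.
    from copy import deepcopy

    import torch.nn as nn


    class StubSelfAttn(nn.Module):
        """Stand-in self-attention; models only the attribute the fix inspects."""

        def __init__(self, rotary_embedding=None) -> None:
            super().__init__()
            self.rotary_embedding = rotary_embedding


    class StubBlock(nn.Module):
        """Mirrors the patched DecoderBlock: flag derived from self_attn."""

        def __init__(self, self_attn: StubSelfAttn) -> None:
            super().__init__()
            self.has_pos_emb = self_attn.rotary_embedding is not None


    class StubDecoder(nn.Module):
        """Mirrors the patched Decoder: no has_pos_emb argument any more."""

        def __init__(self, depth: int, block: StubBlock) -> None:
            super().__init__()
            self.depth = depth
            self.has_pos_emb = block.has_pos_emb
            self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])


    # Without a rotary embedding the decoder reports no positional embedding...
    assert StubDecoder(depth=2, block=StubBlock(StubSelfAttn())).has_pos_emb is False
    # ...and with one it does, without the caller passing any flag.
    assert StubDecoder(depth=2, block=StubBlock(StubSelfAttn(object()))).has_pos_emb is True

The upshot is that callers can no longer pass an inconsistent has_pos_emb value: the flag always reflects whether the block's self-attention actually carries a rotary embedding.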