diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/decoder.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py index 1812e40..db5c766 100644 --- a/text_recognizer/networks/transformer/decoder.py +++ b/text_recognizer/networks/transformer/decoder.py @@ -20,6 +20,7 @@ class DecoderBlock(nn.Module): ) -> None: super().__init__() self.layers = ("self_attn", "cross_attn", "ff") + self.has_pos_emb = self_attn.rotary_embedding is not None self.blocks = self._build(self_attn, norm, ff, cross_attn) def _build( @@ -81,10 +82,10 @@ class DecoderBlock(nn.Module): class Decoder(nn.Module): """Decoder Network.""" - def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None: + def __init__(self, depth: int, block: DecoderBlock) -> None: super().__init__() self.depth = depth - self.has_pos_emb = has_pos_emb + self.has_pos_emb = block.has_pos_emb self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)]) def forward( |