diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-10 00:34:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-10 00:34:20 +0200 |
commit | 7eb0002f599367a5b9a80374c89e08d7a93d6a1b (patch) | |
tree | 7af788a84cd387c42c26d74d4c273193c99151ed /text_recognizer/networks | |
parent | e193ca9d94456c1933e25d2f8b7e8224d3e92ae3 (diff) |
Fix check for pos emb
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/decoder.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py index 1812e40..db5c766 100644 --- a/text_recognizer/networks/transformer/decoder.py +++ b/text_recognizer/networks/transformer/decoder.py @@ -20,6 +20,7 @@ class DecoderBlock(nn.Module): ) -> None: super().__init__() self.layers = ("self_attn", "cross_attn", "ff") + self.has_pos_emb = self_attn.rotary_embedding is not None self.blocks = self._build(self_attn, norm, ff, cross_attn) def _build( @@ -81,10 +82,10 @@ class DecoderBlock(nn.Module): class Decoder(nn.Module): """Decoder Network.""" - def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None: + def __init__(self, depth: int, block: DecoderBlock) -> None: super().__init__() self.depth = depth - self.has_pos_emb = has_pos_emb + self.has_pos_emb = block.has_pos_emb self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)]) def forward( |