diff options
Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r-- | text_recognizer/networks/transformer/decoder.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index c7da226..741f5b3 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -6,16 +6,23 @@ from torch import Tensor, nn
 
 from text_recognizer.networks.transformer.attention import Attention
 from text_recognizer.networks.transformer.decoder_block import DecoderBlock
+from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
 from text_recognizer.networks.transformer.ff import FeedForward
 
 
 class Decoder(nn.Module):
     """Decoder Network."""
 
-    def __init__(self, depth: int, dim: int, block: DecoderBlock) -> None:
+    def __init__(
+        self,
+        depth: int,
+        dim: int,
+        block: DecoderBlock,
+        rotary_embedding: Optional[RotaryEmbedding] = None,
+    ) -> None:
         super().__init__()
         self.depth = depth
-        self.has_pos_emb = block.has_pos_emb
+        self.rotary_embedding = rotary_embedding
         self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
         self.ln = nn.LayerNorm(dim)
 
@@ -23,12 +30,14 @@ class Decoder(nn.Module):
         self,
         x: Tensor,
         context: Optional[Tensor] = None,
-        input_mask: Optional[Tensor] = None,
-        context_mask: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Applies the network to the signals."""
         for block in self.blocks:
             x = block(
-                x=x, context=context, input_mask=input_mask, context_mask=context_mask
+                x=x,
+                context=context,
+                mask=mask,
+                rotary_embedding=self.rotary_embedding,
             )
         return self.ln(x)