Diffstat (limited to 'text_recognizer/networks/transformer/decoder_block.py')
-rw-r--r-- | text_recognizer/networks/transformer/decoder_block.py | 50 |
1 file changed, 50 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
new file mode 100644
index 0000000..e6e7fb8
--- /dev/null
+++ b/text_recognizer/networks/transformer/decoder_block.py
@@ -0,0 +1,50 @@
+"""Transformer decoder module."""
+from copy import deepcopy
+from typing import Optional
+
+from torch import Tensor, nn
+
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+
+
+class DecoderBlock(nn.Module):
+    """Decoder block."""
+
+    def __init__(
+        self,
+        self_attn: Attention,
+        norm: nn.Module,
+        ff: FeedForward,
+        cross_attn: Optional[Attention] = None,
+    ) -> None:
+        super().__init__()
+        self.ln_attn = norm
+        self.attn = self_attn
+        self.ln_cross_attn = deepcopy(norm)
+        self.cross_attn = cross_attn
+        self.ln_ff = deepcopy(norm)
+        self.ff = ff
+        self.has_pos_emb = self.attn.rotary_embedding is not None
+
+    def forward(
+        self,
+        x: Tensor,
+        context: Optional[Tensor] = None,
+        input_mask: Optional[Tensor] = None,
+        context_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Applies decoder block on input signals."""
+        # Pre-norm residual self-attention.
+        x = x + self.attn(self.ln_attn(x), input_mask=input_mask)
+        # Cross-attention is optional; skip it when no cross-attention block was given.
+        # Out-of-place addition is used throughout: in-place `x +=` would bump the
+        # tensor's version counter and break autograd for the preceding LayerNorm.
+        if self.cross_attn is not None:
+            x = x + self.cross_attn(
+                x=self.ln_cross_attn(x),
+                context=context,
+                input_mask=input_mask,
+                context_mask=context_mask,
+            )
+        # Position-wise feedforward, also with a residual connection.
+        x = x + self.ff(self.ln_ff(x))
+        return x
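For context, a minimal usage sketch of the new block follows. The constructor arguments passed to Attention and FeedForward below are assumptions for illustration only; their real signatures live in attention.py and ff.py, which are outside this diff. nn.LayerNorm is the standard PyTorch norm one would pass as the norm instance, since DecoderBlock deepcopies it for each sub-layer.

import torch
from torch import nn

from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.decoder_block import DecoderBlock
from text_recognizer.networks.transformer.ff import FeedForward

dim = 256

block = DecoderBlock(
    self_attn=Attention(dim=dim, num_heads=8, causal=True),    # assumed kwargs
    norm=nn.LayerNorm(dim),                                    # deepcopied per sub-layer
    ff=FeedForward(dim=dim, expansion_factor=4),               # assumed kwargs
    cross_attn=Attention(dim=dim, num_heads=8, causal=False),  # assumed kwargs
)

x = torch.randn(1, 32, dim)        # (batch, sequence, dim) decoder input
context = torch.randn(1, 64, dim)  # encoder output to cross-attend over
out = block(x, context=context)    # -> (1, 32, dim)

Omitting cross_attn yields a decoder-only block: the cross-attention branch is skipped entirely and only self-attention and the feedforward sub-layer are applied.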