diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:00 +0200 |
commit | bf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 (patch) | |
tree | 5679e4d28673b8898769266f8a8d856a4d924b07 /text_recognizer/networks/transformer/decoder_block.py | |
parent | 096823d111117ac5efe954db7f3db26cccabda6c (diff) |
Refactor decoder block
Diffstat (limited to 'text_recognizer/networks/transformer/decoder_block.py')
-rw-r--r-- | text_recognizer/networks/transformer/decoder_block.py | 46 |
1 files changed, 46 insertions, 0 deletions
# File: text_recognizer/networks/transformer/decoder_block.py
"""Transformer decoder module."""
from copy import deepcopy
from typing import Optional, Type

from torch import Tensor, nn

from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.ff import FeedForward


class DecoderBlock(nn.Module):
    """Pre-norm decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer is applied to a normalized copy of the running signal and
    its output is added back to that signal (residual connection).
    """

    def __init__(
        self,
        self_attn: Attention,
        norm: nn.Module,
        ff: FeedForward,
        cross_attn: Optional[Attention] = None,
    ) -> None:
        """Builds the block from already constructed sub-modules.

        Args:
            self_attn: Attention module applied to the input itself. It must
                expose a ``rotary_embedding`` attribute.
            norm: Normalization module *instance*. It is used as-is in front
                of the self-attention and deep-copied for the other two
                sub-layers, so the three norms do not share parameters.
            ff: Feed-forward module applied last.
            cross_attn: Optional attention module that attends over
                ``context``. When ``None`` the cross-attention sub-layer is
                skipped in ``forward``.
        """
        super().__init__()
        self.ln_attn = norm
        self.attn = self_attn
        # Created even without cross-attention so the set of sub-modules
        # (and therefore the state-dict keys) does not depend on `cross_attn`.
        self.ln_cross_attn = deepcopy(norm)
        self.cross_attn = cross_attn
        self.ln_ff = deepcopy(norm)
        self.ff = ff
        self.has_pos_emb = self.attn.rotary_embedding is not None

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        input_mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Applies decoder block on input signals.

        Args:
            x: Input tensor of the decoder.
            context: Tensor the cross-attention attends over; forwarded
                unchanged to ``cross_attn``.
            input_mask: Mask forwarded to both attention modules.
            context_mask: Mask forwarded to ``cross_attn``.

        Returns:
            A new tensor with the residual sub-layers applied; ``x`` itself is
            not modified.
        """
        # Residual adds are deliberately out-of-place: the tensor on the
        # left-hand side was just passed through a norm layer, which may have
        # saved it for the backward pass, so mutating it with `+=` can break
        # autograd.
        x = x + self.attn(self.ln_attn(x), input_mask=input_mask)
        if self.cross_attn is not None:
            x = x + self.cross_attn(
                x=self.ln_cross_attn(x),
                context=context,
                input_mask=input_mask,
                context_mask=context_mask,
            )
        x = x + self.ff(self.ln_ff(x))
        return x