path: root/text_recognizer/networks/transformer/decoder_block.py
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-27 00:10:00 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-27 00:10:00 +0200
commit    bf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 (patch)
tree      5679e4d28673b8898769266f8a8d856a4d924b07 /text_recognizer/networks/transformer/decoder_block.py
parent    096823d111117ac5efe954db7f3db26cccabda6c (diff)
Refactor decoder block
Diffstat (limited to 'text_recognizer/networks/transformer/decoder_block.py')
-rw-r--r--  text_recognizer/networks/transformer/decoder_block.py  |  50
1 file changed, 50 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
new file mode 100644
index 0000000..e6e7fb8
--- /dev/null
+++ b/text_recognizer/networks/transformer/decoder_block.py
@@ -0,0 +1,50 @@
+"""Transformer decoder module."""
+from copy import deepcopy
+from typing import Optional
+
+from torch import Tensor, nn
+
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+
+
+class DecoderBlock(nn.Module):
+ """Decoder block."""
+
+ def __init__(
+ self,
+ self_attn: Attention,
+ norm: Type[nn.Module],
+ ff: FeedForward,
+ cross_attn: Optional[Attention] = None,
+ ) -> None:
+ super().__init__()
+ self.ln_attn = norm
+ self.attn = self_attn
+ self.ln_cross_attn = deepcopy(norm)
+ self.cross_attn = cross_attn
+ self.ln_ff = deepcopy(norm)
+ self.ff = ff
+ self.has_pos_emb = self.attn.rotary_embedding is not None
+
+    def forward(
+        self,
+        x: Tensor,
+        context: Optional[Tensor] = None,
+        input_mask: Optional[Tensor] = None,
+        context_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Applies the decoder block to the input sequence."""
+        x = x + self.attn(self.ln_attn(x), input_mask=input_mask)
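+        # Cross-attention over the encoder context, when configured.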
+        if self.cross_attn is not None and context is not None:
+            x = x + self.cross_attn(
+                x=self.ln_cross_attn(x),
+                context=context,
+                input_mask=input_mask,
+                context_mask=context_mask,
+            )
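+        # Position-wise feed-forward with a residual connection.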
+        x = x + self.ff(self.ln_ff(x))
+        return x
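
Below is a minimal usage sketch of the new block. The constructor arguments for
Attention and FeedForward are assumptions for illustration; their real
signatures live in attention.py and ff.py, which are not part of this diff.

    # Usage sketch -- Attention/FeedForward arguments are hypothetical.
    import torch
    from torch import nn

    from text_recognizer.networks.transformer.attention import Attention
    from text_recognizer.networks.transformer.decoder_block import DecoderBlock
    from text_recognizer.networks.transformer.ff import FeedForward

    dim = 512
    block = DecoderBlock(
        self_attn=Attention(dim=dim, num_heads=8, causal=True),  # hypothetical args
        norm=nn.LayerNorm(dim),  # deepcopied internally for each sublayer
        ff=FeedForward(dim=dim),  # hypothetical args
        cross_attn=Attention(dim=dim, num_heads=8),  # hypothetical args
    )

    x = torch.randn(1, 32, dim)        # (batch, target_len, dim) decoder input
    context = torch.randn(1, 64, dim)  # (batch, source_len, dim) encoder output
    out = block(x, context=context)    # output has the same shape as x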