summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/decoder_block.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/decoder_block.py')
-rw-r--r--text_recognizer/networks/transformer/decoder_block.py44
1 files changed, 0 insertions, 44 deletions
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
deleted file mode 100644
index b8eb5c4..0000000
--- a/text_recognizer/networks/transformer/decoder_block.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Transformer decoder module."""
-from copy import deepcopy
-from typing import Optional, Type
-
-from torch import Tensor, nn
-
-from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
-from text_recognizer.networks.transformer.ff import FeedForward
-
-
class DecoderBlock(nn.Module):
    """Residual pre-norm transformer decoder block.

    Applies, in order: self-attention, optional cross-attention, and a
    feed-forward network — each with a pre-layer-norm and a residual
    connection (``x = x + sublayer(norm(x))``).

    Args:
        self_attn: Self-attention module (called with ``mask`` and
            ``rotary_embedding`` keyword arguments).
        norm: A normalization *instance* (e.g. ``nn.LayerNorm(dim)``); it is
            deep-copied so each sublayer gets its own independent parameters.
        ff: Feed-forward module applied after attention.
        cross_attn: Optional cross-attention module. When ``None`` (the
            default), the cross-attention sublayer is skipped entirely,
            making the block usable in a decoder-only setup.
    """

    def __init__(
        self,
        self_attn: "Attention",
        # NOTE: annotated `nn.Module` (an instance), not `Type[nn.Module]` —
        # the value is stored and called directly, and deep-copied below.
        norm: nn.Module,
        ff: "FeedForward",
        cross_attn: Optional["Attention"] = None,
    ) -> None:
        super().__init__()
        self.ln_attn = norm
        self.attn = self_attn
        # Each sublayer needs its own norm parameters, hence deepcopy.
        self.ln_cross_attn = deepcopy(norm)
        self.cross_attn = cross_attn
        self.ln_ff = deepcopy(norm)
        self.ff = ff

    def forward(
        self,
        x: Tensor,
        rotary_embedding: "RotaryEmbedding",
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the decoder block to ``x``.

        Args:
            x: Decoder input signal.
            rotary_embedding: Rotary positional embedding forwarded to the
                self-attention module.
            context: Encoder output attended to by cross-attention; only
                used when ``cross_attn`` was provided.
            mask: Attention mask forwarded to self-attention.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.attn(self.ln_attn(x), mask=mask, rotary_embedding=rotary_embedding)
        # Guard: cross_attn is Optional — skip the sublayer when absent
        # instead of crashing on a None call (decoder-only usage).
        if self.cross_attn is not None:
            x = x + self.cross_attn(
                x=self.ln_cross_attn(x),
                context=context,
            )
        x = x + self.ff(self.ln_ff(x))
        return x