summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/decoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-01-29 15:52:35 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-01-29 15:52:35 +0100
commit877ee5984dad08379e0c781d35a534b97012e325 (patch)
tree0d74fd324efa4552f9d659f00c5c5be8119c05a5 /text_recognizer/networks/transformer/decoder.py
parent7694f70ec78d748694f818ad9d10ca46c1f04a96 (diff)
feat: add new transformer decoder
Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r--text_recognizer/networks/transformer/decoder.py100
1 files changed, 100 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
new file mode 100644
index 0000000..a58f7bd
--- /dev/null
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -0,0 +1,100 @@
+"""Transformer decoder module."""
+from copy import deepcopy
+from typing import Optional, Tuple, Type
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.mlp import FeedForward
+
+
+class DecoderBlock(nn.Module):
+ """Decoder block."""
+
+ def __init__(
+ self,
+ self_attn: Attention,
+ norm: Type[nn.Module],
+ ff: FeedForward,
+ cross_attn: Optional[Attention] = None,
+ ) -> None:
+ super().__init__()
+ self._layers = ("self_attn", "cross_attn", "ff")
+ self._blocks = self._build(self_attn, norm, ff, cross_attn)
+
+ def _build(
+ self,
+ self_attn: Attention,
+ norm: Type[nn.Module],
+ ff: FeedForward,
+ cross_attn: Optional[Attention],
+ ) -> nn.ModuleDict:
+ return nn.ModuleDict(
+ {
+ self.layers[0]: nn.ModuleList([norm, self_attn]),
+ self.layers[1]: nn.ModuleList([deepcopy(norm), cross_attn]),
+ self.layers[2]: nn.ModuleList([deepcopy(norm), ff]),
+ }
+ )
+
+ def _apply(
+ self,
+ layer: str,
+ x: Tensor,
+ context: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
+ context_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ residual = x
+ norm_fn, layer_fn = self._blocks[layer]
+ if layer == "self_attn":
+ out = layer_fn(x=x, input_mask=input_mask)
+ elif layer == "cross_attn":
+ out = layer_fn(
+ x=x, context=context, input_mask=input_mask, context_mask=context_mask
+ )
+ else:
+ out = layer_fn(x)
+ out += residual
+ return norm_fn(out)
+
+ def forward(
+ self,
+ x: Tensor,
+ context: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
+ context_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Applies decoder block on input signals."""
+ for layer in self._layers:
+ x = self._apply(
+ layer=layer,
+ x=x,
+ context=context,
+ input_mask=input_mask,
+ context_mask=context_mask,
+ )
+ return x
+
+
+class Decoder:
+ """Decoder Network."""
+
+ def __init__(self, depth: int, block: DecoderBlock) -> None:
+ self.depth = depth
+ self.has_pos_emb: bool = block.rotary_embedding is not None
+ self._block = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
+
+ def forward(
+ self,
+ x: Tensor,
+ context: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
+ context_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Applies the network to the signals."""
+ for block in self._blocks:
+ x = block(
+ x=x, context=context, input_mask=input_mask, context_mask=context_mask
+ )
+ return x