From bf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Tue, 27 Sep 2022 00:10:00 +0200
Subject: Refactor decoder block

---
 text_recognizer/networks/transformer/decoder.py    | 77 ++--------------------
 .../networks/transformer/decoder_block.py          | 47 +++++++++++++
 2 files changed, 51 insertions(+), 73 deletions(-)
 create mode 100644 text_recognizer/networks/transformer/decoder_block.py

diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index efa1e89..c7da226 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -5,88 +5,19 @@ from typing import Optional, Type
 from torch import Tensor, nn
 
 from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.decoder_block import DecoderBlock
 from text_recognizer.networks.transformer.ff import FeedForward
 
 
-class DecoderBlock(nn.Module):
-    """Decoder block."""
-
-    def __init__(
-        self,
-        self_attn: Attention,
-        norm: Type[nn.Module],
-        ff: FeedForward,
-        cross_attn: Optional[Attention] = None,
-    ) -> None:
-        super().__init__()
-        self.layers = ("self_attn", "cross_attn", "ff")
-        self.has_pos_emb = self_attn.rotary_embedding is not None
-        self.blocks = self._build(self_attn, norm, ff, cross_attn)
-
-    def _build(
-        self,
-        self_attn: Attention,
-        norm: Type[nn.Module],
-        ff: FeedForward,
-        cross_attn: Optional[Attention],
-    ) -> nn.ModuleDict:
-        return nn.ModuleDict(
-            {
-                self.layers[0]: nn.ModuleList([norm, self_attn]),
-                self.layers[1]: nn.ModuleList([deepcopy(norm), cross_attn]),
-                self.layers[2]: nn.ModuleList([deepcopy(norm), ff]),
-            }
-        )
-
-    def _apply_block(
-        self,
-        layer: str,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        input_mask: Optional[Tensor] = None,
-        context_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies block function."""
-        residual = x
-        norm_fn, layer_fn = self.blocks[layer]
-        if layer == "self_attn":
-            out = layer_fn(x=x, input_mask=input_mask)
-        elif layer == "cross_attn":
-            out = layer_fn(
-                x=x, context=context, input_mask=input_mask, context_mask=context_mask
-            )
-        else:
-            out = layer_fn(x)
-        out += residual
-        return norm_fn(out)
-
-    def forward(
-        self,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        input_mask: Optional[Tensor] = None,
-        context_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies decoder block on input signals."""
-        for layer in self.layers:
-            x = self._apply_block(
-                layer=layer,
-                x=x,
-                context=context,
-                input_mask=input_mask,
-                context_mask=context_mask,
-            )
-        return x
-
-
 class Decoder(nn.Module):
     """Decoder Network."""
 
-    def __init__(self, depth: int, block: DecoderBlock) -> None:
+    def __init__(self, depth: int, dim: int, block: DecoderBlock) -> None:
         super().__init__()
         self.depth = depth
         self.has_pos_emb = block.has_pos_emb
         self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
+        self.ln = nn.LayerNorm(dim)
 
     def forward(
         self,
@@ -100,4 +31,4 @@ class Decoder(nn.Module):
             x = block(
                 x=x, context=context, input_mask=input_mask, context_mask=context_mask
             )
-        return x
+        return self.ln(x)
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
new file mode 100644
index 0000000..e6e7fb8
--- /dev/null
+++ b/text_recognizer/networks/transformer/decoder_block.py
@@ -0,0 +1,47 @@
+"""Transformer decoder module."""
+from copy import deepcopy
+from typing import Optional
+
+from torch import Tensor, nn
+
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+
+
+class DecoderBlock(nn.Module):
+    """Decoder block."""
+
+    def __init__(
+        self,
+        self_attn: Attention,
+        norm: nn.Module,
+        ff: FeedForward,
+        cross_attn: Optional[Attention] = None,
+    ) -> None:
+        super().__init__()
+        self.ln_attn = norm
+        self.attn = self_attn
+        self.ln_cross_attn = deepcopy(norm)
+        self.cross_attn = cross_attn
+        self.ln_ff = deepcopy(norm)
+        self.ff = ff
+        self.has_pos_emb = self.attn.rotary_embedding is not None
+
+    def forward(
+        self,
+        x: Tensor,
+        context: Optional[Tensor] = None,
+        input_mask: Optional[Tensor] = None,
+        context_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Applies decoder block on input signals."""
+        x = x + self.attn(self.ln_attn(x), input_mask=input_mask)
+        if self.cross_attn is not None:
+            x = x + self.cross_attn(
+                x=self.ln_cross_attn(x),
+                context=context,
+                input_mask=input_mask,
+                context_mask=context_mask,
+            )
+        x = x + self.ff(self.ln_ff(x))
+        return x
-- 
cgit v1.2.3-70-g09d2
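
For context, below is a minimal usage sketch of how the refactored classes fit
together after this change. It is not part of the patch: the repository's
Attention and FeedForward constructors are not shown here, so StubAttention and
the nn.Sequential feed-forward are hypothetical stand-ins that only reproduce
the call signature and the rotary_embedding attribute that DecoderBlock relies
on.

    from typing import Optional

    import torch
    from torch import Tensor, nn

    from text_recognizer.networks.transformer.decoder import Decoder
    from text_recognizer.networks.transformer.decoder_block import DecoderBlock


    class StubAttention(nn.Module):
        """Hypothetical stand-in for the repository's Attention class."""

        def __init__(self, dim: int, num_heads: int = 4) -> None:
            super().__init__()
            # DecoderBlock only checks whether this attribute is None.
            self.rotary_embedding = None
            self.mha = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        def forward(
            self,
            x: Tensor,
            context: Optional[Tensor] = None,
            input_mask: Optional[Tensor] = None,
            context_mask: Optional[Tensor] = None,
        ) -> Tensor:
            # Masks are ignored in this sketch; keys/values come from the
            # context when given (cross-attention), otherwise from x.
            kv = x if context is None else context
            out, _ = self.mha(x, kv, kv)
            return out


    dim = 128
    block = DecoderBlock(
        self_attn=StubAttention(dim),
        norm=nn.LayerNorm(dim),
        ff=nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)),
        cross_attn=StubAttention(dim),
    )
    decoder = Decoder(depth=4, dim=dim, block=block)

    tokens = torch.randn(2, 32, dim)         # (batch, seq_len, dim)
    image_context = torch.randn(2, 64, dim)  # e.g. flattened encoder output
    out = decoder(tokens, context=image_context)
    print(out.shape)  # torch.Size([2, 32, 128])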