From 812ae85502f7453d457399ecd2b6e4cefae39fa7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 29 Jan 2022 17:09:37 +0100
Subject: fix(decoder): rename shadowed members and subclass nn.Module
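
The renames are more than cosmetic: DecoderBlock._apply shadowed
nn.Module._apply, which torch calls internally from .to(), .cuda(),
and .float(), so any device or dtype move of the block would crash.
A minimal sketch of that hazard, using a hypothetical Shadowing class
rather than code from this repo:

    import torch
    from torch import nn

    class Shadowing(nn.Module):
        # Overrides nn.Module._apply with an incompatible signature.
        def _apply(self, layer, x):
            return x

    # Raises TypeError: Module.to() calls self._apply(convert),
    # which binds layer=convert and leaves x without an argument.
    Shadowing().to(torch.float32)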

---
 text_recognizer/networks/transformer/decoder.py | 26 +++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)
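
Note: the switch to class Decoder(nn.Module) is the other substantive
fix. As a plain class, the deep-copied blocks were never registered
with torch, so decoder.parameters() and state_dict() did not exist
and an optimizer could not see any of the weights. A sketch of the
difference, with hypothetical PlainHolder/ModuleHolder classes
standing in for the old and new Decoder:

    from torch import nn

    class PlainHolder:  # like the old Decoder: no nn.Module base
        def __init__(self) -> None:
            self.blocks = nn.ModuleList([nn.Linear(4, 4)])

    class ModuleHolder(nn.Module):  # like the patched Decoder
        def __init__(self) -> None:
            super().__init__()
            self.blocks = nn.ModuleList([nn.Linear(4, 4)])

    print(hasattr(PlainHolder(), "parameters"))    # False
    print(len(list(ModuleHolder().parameters())))  # 2: weight and bias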

diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index a58f7bd..1812e40 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -1,6 +1,6 @@
 """Transformer decoder module."""
 from copy import deepcopy
-from typing import Optional, Tuple, Type
+from typing import Optional, Type
 
 from torch import nn, Tensor
 
@@ -19,8 +19,8 @@ class DecoderBlock(nn.Module):
         cross_attn: Optional[Attention] = None,
     ) -> None:
         super().__init__()
-        self._layers = ("self_attn", "cross_attn", "ff")
-        self._blocks = self._build(self_attn, norm, ff, cross_attn)
+        self.layers = ("self_attn", "cross_attn", "ff")
+        self.blocks = self._build(self_attn, norm, ff, cross_attn)
 
     def _build(
         self,
@@ -37,7 +37,7 @@ class DecoderBlock(nn.Module):
             }
         )
 
-    def _apply(
+    def _apply_block(
         self,
         layer: str,
         x: Tensor,
@@ -45,8 +45,9 @@ class DecoderBlock(nn.Module):
         input_mask: Optional[Tensor] = None,
         context_mask: Optional[Tensor] = None,
     ) -> Tensor:
+        """Applies block function."""
         residual = x
-        norm_fn, layer_fn = self._blocks[layer]
+        norm_fn, layer_fn = self.blocks[layer]
         if layer == "self_attn":
             out = layer_fn(x=x, input_mask=input_mask)
         elif layer == "cross_attn":
@@ -66,8 +67,8 @@ class DecoderBlock(nn.Module):
         context_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Applies decoder block on input signals."""
-        for layer in self._layers:
-            x = self._apply(
+        for layer in self.layers:
+            x = self._apply_block(
                 layer=layer,
                 x=x,
                 context=context,
@@ -77,13 +78,14 @@ class DecoderBlock(nn.Module):
         return x
 
 
-class Decoder:
+class Decoder(nn.Module):
     """Decoder Network."""
 
-    def __init__(self, depth: int, block: DecoderBlock) -> None:
+    def __init__(self, depth: int, has_pos_emb: bool, block: DecoderBlock) -> None:
+        super().__init__()
         self.depth = depth
-        self.has_pos_emb: bool = block.rotary_embedding is not None
-        self._block = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
+        self.has_pos_emb = has_pos_emb
+        self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
 
     def forward(
         self,
@@ -93,7 +95,7 @@ class Decoder:
         context_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Applies the network to the signals."""
-        for block in self._blocks:
+        for block in self.blocks:
             x = block(
                 x=x, context=context, input_mask=input_mask, context_mask=context_mask
             )
-- 
cgit v1.2.3-70-g09d2