From 810d8b2403dd0a229063c5693deac694871243f6 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 27 Oct 2021 22:16:04 +0200
Subject: Add docstrings to transformer modules

---
 text_recognizer/networks/transformer/norm.py     | 5 +++--
 text_recognizer/networks/transformer/residual.py | 2 ++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 4930adf..c59744a 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -4,7 +4,7 @@ Copied from lucidrains:
     https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
 
 """
-from typing import Callable, Dict, Type
+from typing import Dict, Type
 
 import torch
 from torch import nn
@@ -19,6 +19,7 @@ class ScaleNorm(nn.Module):
         self.g = nn.Parameter(torch.ones(1))
 
     def forward(self, x: Tensor) -> Tensor:
+        """Applies scale norm."""
         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
         return x / norm.clamp(min=self.eps) * self.g
 
@@ -30,6 +31,6 @@ class PreNorm(nn.Module):
         self.fn = fn
 
     def forward(self, x: Tensor, **kwargs: Dict) -> Tensor:
-        """Norm tensor."""
+        """Applies pre norm."""
         x = self.norm(x)
         return self.fn(x, **kwargs)
diff --git a/text_recognizer/networks/transformer/residual.py b/text_recognizer/networks/transformer/residual.py
index 1547df6..825a0fc 100644
--- a/text_recognizer/networks/transformer/residual.py
+++ b/text_recognizer/networks/transformer/residual.py
@@ -3,6 +3,8 @@ from torch import nn, Tensor
 
 
 class Residual(nn.Module):
+    """Residual block."""
+
     def forward(self, x: Tensor, residual: Tensor) -> Tensor:
         """Applies the residual function."""
         return x + residual
-- 
cgit v1.2.3-70-g09d2
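
Below is a minimal, self-contained sketch of how the three modules touched in this
patch might be composed into a pre-norm transformer sub-layer. The constructor
signatures for ScaleNorm and PreNorm and the feed-forward sub-layer are assumptions
for illustration; only the forward() bodies mirror the patched files.

import torch
from torch import nn, Tensor


class ScaleNorm(nn.Module):
    """Scale norm; forward() mirrors the patched norm.py, __init__ is assumed."""

    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = dim ** -0.5  # assumed scaling, following lucidrains' x-transformers
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x: Tensor) -> Tensor:
        """Applies scale norm."""
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class PreNorm(nn.Module):
    """Pre norm wrapper; forward() mirrors the patched norm.py, __init__ is assumed."""

    def __init__(self, norm: nn.Module, fn: nn.Module) -> None:
        super().__init__()
        self.norm = norm
        self.fn = fn

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Applies pre norm."""
        x = self.norm(x)
        return self.fn(x, **kwargs)


class Residual(nn.Module):
    """Residual block."""

    def forward(self, x: Tensor, residual: Tensor) -> Tensor:
        """Applies the residual function."""
        return x + residual


if __name__ == "__main__":
    dim = 64
    # Feed-forward sub-layer used here purely for illustration.
    ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
    block = PreNorm(ScaleNorm(dim), ff)
    skip = Residual()

    x = torch.randn(2, 16, dim)   # (batch, sequence length, dim)
    out = skip(block(x), x)       # normalize -> feed-forward -> add skip connection
    print(out.shape)              # torch.Size([2, 16, 64])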