From ffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 2 Oct 2022 01:45:34 +0200
Subject: Add comments

---
 text_recognizer/networks/convnext/convnext.py   | 7 +++++--
 text_recognizer/networks/convnext/downsample.py | 4 ++++
 text_recognizer/networks/convnext/norm.py       | 3 +++
 text_recognizer/networks/convnext/residual.py   | 3 +++
 4 files changed, 15 insertions(+), 2 deletions(-)

(limited to 'text_recognizer/networks/convnext')

diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index b4dfad7..9419a15 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -1,3 +1,4 @@
+"""ConvNext module."""
 from typing import Optional, Sequence
 
 from torch import Tensor, nn
@@ -8,7 +9,9 @@ from text_recognizer.networks.convnext.norm import LayerNorm
 
 
 class ConvNextBlock(nn.Module):
-    def __init__(self, dim, dim_out, mult):
+    """ConvNext block."""
+
+    def __init__(self, dim: int, dim_out: int, mult: int) -> None:
         super().__init__()
         self.ds_conv = nn.Conv2d(
             dim, dim, kernel_size=(7, 7), padding="same", groups=dim
@@ -21,7 +24,7 @@ class ConvNextBlock(nn.Module):
         )
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         h = self.ds_conv(x)
         h = self.net(h)
         return h + self.res_conv(x)
diff --git a/text_recognizer/networks/convnext/downsample.py b/text_recognizer/networks/convnext/downsample.py
index c28ecca..a8a0466 100644
--- a/text_recognizer/networks/convnext/downsample.py
+++ b/text_recognizer/networks/convnext/downsample.py
@@ -1,3 +1,4 @@
+"""Convnext downsample module."""
 from typing import Tuple
 
 from einops.layers.torch import Rearrange
@@ -5,6 +6,8 @@ from torch import Tensor, nn
 
 
 class Downsample(nn.Module):
+    """Downsamples feature maps by patches."""
+
     def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
         super().__init__()
         s1, s2 = factors
@@ -14,4 +17,5 @@ class Downsample(nn.Module):
         )
 
     def forward(self, x: Tensor) -> Tensor:
+        """Applies patch function."""
         return self.fn(x)
diff --git a/text_recognizer/networks/convnext/norm.py b/text_recognizer/networks/convnext/norm.py
index 23cf07a..3355de9 100644
--- a/text_recognizer/networks/convnext/norm.py
+++ b/text_recognizer/networks/convnext/norm.py
@@ -4,11 +4,14 @@ from torch import Tensor, nn
 
 
 class LayerNorm(nn.Module):
+    """Layer norm for convolutions."""
+
     def __init__(self, dim: int) -> None:
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x: Tensor) -> Tensor:
+        """Applies layer norm."""
         eps = 1e-5 if x.dtype == torch.float32 else 1e-3
         var = torch.var(x, dim=1, unbiased=False, keepdim=True)
         mean = torch.mean(x, dim=1, keepdim=True)
diff --git a/text_recognizer/networks/convnext/residual.py b/text_recognizer/networks/convnext/residual.py
index 8e76ae9..dfc2847 100644
--- a/text_recognizer/networks/convnext/residual.py
+++ b/text_recognizer/networks/convnext/residual.py
@@ -5,9 +5,12 @@ from torch import Tensor, nn
 
 
 class Residual(nn.Module):
+    """Residual layer."""
+
     def __init__(self, fn: Callable) -> None:
         super().__init__()
         self.fn = fn
 
     def forward(self, x: Tensor) -> Tensor:
+        """Applies residual fn."""
         return self.fn(x) + x
-- 
cgit v1.2.3-70-g09d2