From ffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 2 Oct 2022 01:45:34 +0200
Subject: Add comments

---
 text_recognizer/networks/conv_transformer.py       |  5 ++---
 text_recognizer/networks/convnext/convnext.py      |  7 +++++--
 text_recognizer/networks/convnext/downsample.py    |  4 ++++
 text_recognizer/networks/convnext/norm.py          |  3 +++
 text_recognizer/networks/convnext/residual.py      |  3 +++
 text_recognizer/networks/image_encoder.py          |  2 +-
 text_recognizer/networks/text_decoder.py           | 14 +++++--------
 .../networks/transformer/embeddings/axial.py       | 24 +++++++++++++++++-----
 8 files changed, 42 insertions(+), 20 deletions(-)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index e36a786..d36162a 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,7 +1,6 @@
 """Base network module."""
 from typing import Type
 
-import torch
 from torch import Tensor, nn
 
 from text_recognizer.networks.transformer.decoder import Decoder
@@ -28,11 +27,11 @@ class ConvTransformer(nn.Module):
         return self.decoder(tokens, img_features)
 
     def forward(self, img: Tensor, tokens: Tensor) -> Tensor:
-        """Encodes images into word piece logtis.
+        """Encodes images into token logits.
 
         Args:
             img (Tensor): Input image(s).
-            tokens (Tensor): Target word embeddings.
+            tokens (Tensor): Token embeddings.
 
         Shapes:
             - img: :math: `(B, 1, H, W)`
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index b4dfad7..9419a15 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -1,3 +1,4 @@
+"""ConvNext module."""
 from typing import Optional, Sequence
 
 from torch import Tensor, nn
@@ -8,7 +9,9 @@ from text_recognizer.networks.convnext.norm import LayerNorm
 
 
 class ConvNextBlock(nn.Module):
-    def __init__(self, dim, dim_out, mult):
+    """ConvNext block."""
+
+    def __init__(self, dim: int, dim_out: int, mult: int) -> None:
         super().__init__()
         self.ds_conv = nn.Conv2d(
             dim, dim, kernel_size=(7, 7), padding="same", groups=dim
@@ -21,7 +24,7 @@ class ConvNextBlock(nn.Module):
         )
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         h = self.ds_conv(x)
         h = self.net(h)
         return h + self.res_conv(x)
diff --git a/text_recognizer/networks/convnext/downsample.py b/text_recognizer/networks/convnext/downsample.py
index c28ecca..a8a0466 100644
--- a/text_recognizer/networks/convnext/downsample.py
+++ b/text_recognizer/networks/convnext/downsample.py
@@ -1,3 +1,4 @@
+"""ConvNext downsample module."""
 from typing import Tuple
 
 from einops.layers.torch import Rearrange
@@ -5,6 +6,8 @@ from torch import Tensor, nn
 
 
 class Downsample(nn.Module):
+    """Downsamples feature maps by patches."""
+
     def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
         super().__init__()
         s1, s2 = factors
@@ -14,4 +17,5 @@ class Downsample(nn.Module):
         )
 
     def forward(self, x: Tensor) -> Tensor:
+        """Applies patch function."""
         return self.fn(x)
diff --git a/text_recognizer/networks/convnext/norm.py b/text_recognizer/networks/convnext/norm.py
index 23cf07a..3355de9 100644
--- a/text_recognizer/networks/convnext/norm.py
+++ b/text_recognizer/networks/convnext/norm.py
@@ -4,11 +4,14 @@ from torch import Tensor, nn
 
 
 class LayerNorm(nn.Module):
+    """Layer norm for convolutions."""
+
     def __init__(self, dim: int) -> None:
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x: Tensor) -> Tensor:
+        """Applies layer norm."""
         eps = 1e-5 if x.dtype == torch.float32 else 1e-3
         var = torch.var(x, dim=1, unbiased=False, keepdim=True)
         mean = torch.mean(x, dim=1, keepdim=True)
diff --git a/text_recognizer/networks/convnext/residual.py b/text_recognizer/networks/convnext/residual.py
index 8e76ae9..dfc2847 100644
--- a/text_recognizer/networks/convnext/residual.py
+++ b/text_recognizer/networks/convnext/residual.py
@@ -5,9 +5,12 @@ from torch import Tensor, nn
 
 
 class Residual(nn.Module):
+    """Residual layer."""
+
     def __init__(self, fn: Callable) -> None:
         super().__init__()
         self.fn = fn
 
     def forward(self, x: Tensor) -> Tensor:
+        """Applies residual fn."""
         return self.fn(x) + x
diff --git a/text_recognizer/networks/image_encoder.py b/text_recognizer/networks/image_encoder.py
index b5fd0c5..ab60560 100644
--- a/text_recognizer/networks/image_encoder.py
+++ b/text_recognizer/networks/image_encoder.py
@@ -9,7 +9,7 @@ from text_recognizer.networks.transformer.embeddings.axial import (
 
 
 class ImageEncoder(nn.Module):
-    """Base transformer network."""
+    """Encodes images to latent embeddings."""
 
     def __init__(
         self,
diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py
index 7ee6720..7498663 100644
--- a/text_recognizer/networks/text_decoder.py
+++ b/text_recognizer/networks/text_decoder.py
@@ -1,6 +1,4 @@
 """Text decoder."""
-from typing import Optional, Type
-
 import torch
 from torch import Tensor, nn
 
@@ -8,26 +6,24 @@ from text_recognizer.networks.transformer.decoder import Decoder
 
 
 class TextDecoder(nn.Module):
-    """Decoder transformer network."""
+    """Decodes latent image embeddings to token logits."""
 
     def __init__(
         self,
-        hidden_dim: int,
+        dim: int,
         num_classes: int,
         pad_index: Tensor,
         decoder: Decoder,
     ) -> None:
         super().__init__()
-        self.hidden_dim = hidden_dim
+        self.dim = dim
         self.num_classes = num_classes
         self.pad_index = pad_index
         self.decoder = decoder
         self.token_embedding = nn.Embedding(
-            num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
-        )
-        self.to_logits = nn.Linear(
-            in_features=self.hidden_dim, out_features=self.num_classes
+            num_embeddings=self.num_classes, embedding_dim=self.dim
         )
+        self.to_logits = nn.Linear(in_features=self.dim, out_features=self.num_classes)
 
     def forward(self, tokens: Tensor, img_features: Tensor) -> Tensor:
         """Decodes latent images embedding into word pieces.
diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py
index 25d8f60..9b872a9 100644
--- a/text_recognizer/networks/transformer/embeddings/axial.py
+++ b/text_recognizer/networks/transformer/embeddings/axial.py
@@ -1,17 +1,24 @@
 """Axial attention for multi-dimensional data.
 
 Stolen from:
-    https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
+    https://github.com/lucidrains/axial-attention/blob/
+    eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
 """
 from functools import reduce
 from operator import mul
+from typing import Optional, Sequence
 
 import torch
-from torch import nn
+from torch import Tensor, nn
 
 
 class AxialPositionalEmbedding(nn.Module):
-    def __init__(self, dim, axial_shape, axial_dims=None):
+    def __init__(
+        self,
+        dim: int,
+        axial_shape: Sequence[int],
+        axial_dims: Optional[Sequence[int]] = None,
+    ) -> None:
         super().__init__()
 
         self.dim = dim
@@ -37,7 +44,8 @@ class AxialPositionalEmbedding(nn.Module):
             ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1))
             self.weights.append(ax_emb)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
+        """Returns axial positional embedding."""
         b, t, _ = x.shape
         assert (
             t <= self.max_seq_len
@@ -77,8 +85,14 @@ class ParameterList(object):
 
 
 class AxialPositionalEmbeddingImage(nn.Module):
-    def __init__(self, dim, axial_shape, axial_dims=None):
+    def __init__(
+        self,
+        dim: int,
+        axial_shape: Sequence[int],
+        axial_dims: Optional[Sequence[int]] = None,
+    ) -> None:
         super().__init__()
+        axial_dims = (dim // 2, dim // 2) if axial_dims is None else axial_dims
         assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images"
         self.dim = dim
         self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims)
-- 
cgit v1.2.3-70-g09d2