From 8a7d47e9a432ec927993cc546dacb89a97a05cda Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 27 Oct 2021 22:15:43 +0200
Subject: Clean up local attention, add comments and types

---
 .../networks/transformer/local_attention.py        | 24 +++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index db5bebc..134089c 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -20,17 +20,21 @@ from text_recognizer.networks.transformer.attention import apply_rotary_emb
 
 @attr.s(eq=False)
 class LocalAttention(nn.Module):
+    """Local windowed attention."""
+
     dim: int = attr.ib()
+    num_heads: int = attr.ib()
     dim_head: int = attr.ib(default=64)
     window_size: int = attr.ib(default=128)
     look_back: int = attr.ib(default=1)
     dropout_rate: float = attr.ib(default=0.0)
 
     def __attrs_pre_init__(self) -> None:
+        """Pre init constructor."""
         super().__init__()
 
     def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+        """Post init constructor."""
         self.scale = self.dim ** -0.5
         inner_dim = self.dim * self.dim_head
 
@@ -49,10 +53,12 @@ class LocalAttention(nn.Module):
         mask: Optional[Tensor] = None,
         rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
-        b, n, _, device, dtype = *x.shape, x.device, x.dtype
-        assert (
-            n % self.window_size
-        ), f"Sequence length {n} must be divisable with window size {self.window_size}"
+        """Computes windowed attention."""
+        b, n, _ = x.shape
+        if not n % self.window_size:
+            RuntimeError(
+                f"Sequence length {n} must be divisable with window size {self.window_size}"
+            )
 
         q = self.query(x)
         k = self.key(x)
@@ -111,14 +117,16 @@ class LocalAttention(nn.Module):
         return out, attn
 
 
-def merge_dims(ind_from, ind_to, tensor):
+def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
+    """Merge dimensions."""
     shape = list(tensor.shape)
     arr_slice = slice(ind_from, ind_to + 1)
     shape[arr_slice] = [reduce(mul, shape[arr_slice])]
     return tensor.reshape(*shape)
 
 
-def expand_dim(t, dim, k, unsqueeze=True):
+def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
+    """Expand tensors dimensions."""
     if unsqueeze:
         t = t.unsqueeze(dim)
     expand_shape = [-1] * len(t.shape)
@@ -127,6 +135,7 @@ def expand_dim(t, dim, k, unsqueeze=True):
 
 
 def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
+    """Apply windowing."""
     n = x.shape[1]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
@@ -143,6 +152,7 @@ def apply_input_mask(
     num_windows: int,
     mask_value: Tensor,
 ) -> Tensor:
+    """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
     mask = mask.reshape(-1, window_size, num_windows)
     mq = mk = mask
-- 
cgit v1.2.3-70-g09d2