-rw-r--r--  text_recognizer/networks/transformer/local_attention.py  24
1 file changed, 17 insertions(+), 7 deletions(-)
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index db5bebc..134089c 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -20,17 +20,21 @@ from text_recognizer.networks.transformer.attention import apply_rotary_emb
 
 @attr.s(eq=False)
 class LocalAttention(nn.Module):
+    """Local windowed attention."""
+
     dim: int = attr.ib()
+    num_heads: int = attr.ib()
     dim_head: int = attr.ib(default=64)
     window_size: int = attr.ib(default=128)
     look_back: int = attr.ib(default=1)
     dropout_rate: float = attr.ib(default=0.0)
 
     def __attrs_pre_init__(self) -> None:
+        """Pre init constructor."""
         super().__init__()
 
     def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+        """Post init constructor."""
         self.scale = self.dim ** -0.5
         inner_dim = self.dim * self.dim_head
@@ -49,10 +53,12 @@ class LocalAttention(nn.Module):
         mask: Optional[Tensor] = None,
         rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
-        b, n, _, device, dtype = *x.shape, x.device, x.dtype
-        assert (
-            n % self.window_size
-        ), f"Sequence length {n} must be divisable with window size {self.window_size}"
+        """Computes windowed attention."""
+        b, n, _ = x.shape
+        if n % self.window_size:
+            raise RuntimeError(
+                f"Sequence length {n} must be divisible by window size {self.window_size}"
+            )
 
         q = self.query(x)
         k = self.key(x)
@@ -111,14 +117,16 @@ class LocalAttention(nn.Module):
         return out, attn
 
 
-def merge_dims(ind_from, ind_to, tensor):
+def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
+    """Merge dimensions."""
     shape = list(tensor.shape)
     arr_slice = slice(ind_from, ind_to + 1)
     shape[arr_slice] = [reduce(mul, shape[arr_slice])]
     return tensor.reshape(*shape)
 
 
-def expand_dim(t, dim, k, unsqueeze=True):
+def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
+    """Expands a tensor dimension."""
     if unsqueeze:
         t = t.unsqueeze(dim)
     expand_shape = [-1] * len(t.shape)
@@ -127,6 +135,7 @@ def expand_dim(t, dim, k, unsqueeze=True):
 
 
 def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
+    """Apply windowing."""
     n = x.shape[1]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
@@ -143,6 +152,7 @@ def apply_input_mask(
     num_windows: int,
     mask_value: Tensor,
 ) -> Tensor:
+    """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
     mask = mask.reshape(-1, window_size, num_windows)
     mq = mk = mask
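
For reference, a minimal usage sketch of the updated class (illustrative only, not part of this commit). It assumes the unchanged parts of local_attention.py define the projection layers used in forward, and it respects the requirement that the sequence length be a multiple of window_size:

import torch

from text_recognizer.networks.transformer.local_attention import LocalAttention

# num_heads is now a required field next to dim; the remaining fields keep their defaults.
attn = LocalAttention(dim=128, num_heads=4, dim_head=64, window_size=128)

x = torch.randn(1, 256, 128)  # (batch, seq_len, dim); 256 is a multiple of window_size
out, attn_weights = attn(x)   # forward returns the output and the attention weights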