-rw-r--r-- | text_recognizer/networks/transformer/local_attention.py | 24
1 file changed, 17 insertions(+), 7 deletions(-)
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index db5bebc..134089c 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -20,17 +20,21 @@ from text_recognizer.networks.transformer.attention import apply_rotary_emb
 
 
 @attr.s(eq=False)
 class LocalAttention(nn.Module):
+    """Local windowed attention."""
+
     dim: int = attr.ib()
+    num_heads: int = attr.ib()
     dim_head: int = attr.ib(default=64)
     window_size: int = attr.ib(default=128)
     look_back: int = attr.ib(default=1)
     dropout_rate: float = attr.ib(default=0.0)
 
     def __attrs_pre_init__(self) -> None:
+        """Pre init constructor."""
         super().__init__()
 
     def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+        """Post init constructor."""
         self.scale = self.dim ** -0.5
         inner_dim = self.dim * self.dim_head
@@ -49,10 +53,12 @@ class LocalAttention(nn.Module):
         mask: Optional[Tensor] = None,
         rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
-        b, n, _, device, dtype = *x.shape, x.device, x.dtype
-        assert (
-            n % self.window_size
-        ), f"Sequence length {n} must be divisable with window size {self.window_size}"
+        """Computes windowed attention."""
+        b, n, _ = x.shape
+        if n % self.window_size:
+            raise RuntimeError(
+                f"Sequence length {n} must be divisible by window size {self.window_size}"
+            )
 
         q = self.query(x)
         k = self.key(x)
@@ -111,14 +117,16 @@ class LocalAttention(nn.Module):
         return out, attn
 
 
-def merge_dims(ind_from, ind_to, tensor):
+def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
+    """Merge dimensions."""
     shape = list(tensor.shape)
     arr_slice = slice(ind_from, ind_to + 1)
     shape[arr_slice] = [reduce(mul, shape[arr_slice])]
     return tensor.reshape(*shape)
 
 
-def expand_dim(t, dim, k, unsqueeze=True):
+def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
+    """Expand tensors dimensions."""
     if unsqueeze:
         t = t.unsqueeze(dim)
     expand_shape = [-1] * len(t.shape)
@@ -127,6 +135,7 @@ def expand_dim(t, dim, k, unsqueeze=True):
 
 
 def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
+    """Apply windowing."""
     n = x.shape[1]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
@@ -143,6 +152,7 @@ def apply_input_mask(
     num_windows: int,
     mask_value: Tensor,
 ) -> Tensor:
+    """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
     mask = mask.reshape(-1, window_size, num_windows)
     mq = mk = mask
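
For context on the helpers touched above: merge_dims collapses a contiguous range of tensor dimensions into one, which windowed attention typically needs after splitting a sequence into fixed-size windows, and the divisibility check exists because that split only works when the sequence length is a multiple of window_size. The following is a minimal, self-contained sketch of both ideas; the shapes and variable names below are illustrative only and are not taken from the model configuration.

from functools import reduce
from operator import mul

import torch
from torch import Tensor


def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
    """Collapse dimensions ind_from..ind_to (inclusive) into a single dimension."""
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)


b, n, d, window_size = 2, 256, 64, 128
x = torch.randn(b, n, d)

# Windowing only works when the sequence length is divisible by the window size.
if n % window_size:
    raise RuntimeError(f"Sequence length {n} must be divisible by window size {window_size}")

# Split the sequence into fixed-size windows: (b, num_windows, window_size, d).
windows = x.reshape(b, n // window_size, window_size, d)

# Merge the batch and window dimensions so each window can be attended to independently.
flat = merge_dims(0, 1, windows)
print(flat.shape)  # torch.Size([4, 128, 64])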