Diffstat (limited to 'text_recognizer')
| -rw-r--r-- | text_recognizer/networks/transformer/local_attention.py | 24 | 
1 files changed, 17 insertions, 7 deletions
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index db5bebc..134089c 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -20,17 +20,21 @@ from text_recognizer.networks.transformer.attention import apply_rotary_emb
 
 @attr.s(eq=False)
 class LocalAttention(nn.Module):
+    """Local windowed attention."""
+
     dim: int = attr.ib()
+    num_heads: int = attr.ib()
     dim_head: int = attr.ib(default=64)
     window_size: int = attr.ib(default=128)
     look_back: int = attr.ib(default=1)
     dropout_rate: float = attr.ib(default=0.0)
 
     def __attrs_pre_init__(self) -> None:
+        """Pre init constructor."""
         super().__init__()
 
     def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+        """Post init constructor."""
         self.scale = self.dim ** -0.5
 
         inner_dim = self.dim * self.dim_head
@@ -49,10 +53,12 @@ class LocalAttention(nn.Module):
         mask: Optional[Tensor] = None,
         rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
-        b, n, _, device, dtype = *x.shape, x.device, x.dtype
-        assert (
-            n % self.window_size
-        ), f"Sequence length {n} must be divisable with window size {self.window_size}"
+        """Computes windowed attention."""
+        b, n, _ = x.shape
+        if not n % self.window_size:
+            RuntimeError(
+                f"Sequence length {n} must be divisable with window size {self.window_size}"
+            )
 
         q = self.query(x)
         k = self.key(x)
@@ -111,14 +117,16 @@ class LocalAttention(nn.Module):
         return out, attn
 
 
-def merge_dims(ind_from, ind_to, tensor):
+def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
+    """Merge dimensions."""
     shape = list(tensor.shape)
     arr_slice = slice(ind_from, ind_to + 1)
     shape[arr_slice] = [reduce(mul, shape[arr_slice])]
     return tensor.reshape(*shape)
 
 
-def expand_dim(t, dim, k, unsqueeze=True):
+def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
+    """Expand tensors dimensions."""
     if unsqueeze:
         t = t.unsqueeze(dim)
     expand_shape = [-1] * len(t.shape)
@@ -127,6 +135,7 @@ def expand_dim(t, dim, k, unsqueeze=True):
 
 
 def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
+    """Apply windowing."""
     n = x.shape[1]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
@@ -143,6 +152,7 @@ def apply_input_mask(
     num_windows: int,
     mask_value: Tensor,
 ) -> Tensor:
+    """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
     mask = mask.reshape(-1, window_size, num_windows)
     mq = mk = mask
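Note on the new length check: `not n % self.window_size` is true only when the sequence length already is a multiple of the window size, and the RuntimeError is constructed but never raised, so the committed guard never fires. A minimal sketch of a guard that enforces the requirement stated in the error message (the helper name check_window_divisibility is illustrative, not part of the commit):

def check_window_divisibility(n: int, window_size: int) -> None:
    """Raise if the sequence length is not an even multiple of the window size."""
    if n % window_size != 0:
        raise RuntimeError(
            f"Sequence length {n} must be divisible by window size {window_size}"
        )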
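For reference, merge_dims (fully visible in the hunk above) collapses an inclusive range of dimensions into a single dimension by multiplying their sizes. A small usage sketch; the tensor shape and the torch import are illustrative, not from the commit:

from functools import reduce
from operator import mul

import torch
from torch import Tensor


def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
    """Collapse dimensions ind_from..ind_to (inclusive) into one dimension."""
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)


x = torch.randn(2, 3, 4, 5)
print(merge_dims(1, 2, x).shape)  # torch.Size([2, 12, 5])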