Diffstat (limited to 'text_recognizer/networks')
| -rw-r--r-- | text_recognizer/networks/transformer/local_attention.py | 44 | 
1 file changed, 26 insertions, 18 deletions
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index 134089c..002069c 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -54,7 +54,7 @@ class LocalAttention(nn.Module):
         rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """Computes windowed attention."""
-        b, n, _ = x.shape
+        b, n, d = x.shape
         if not n % self.window_size:
             RuntimeError(
                 f"Sequence length {n} must be divisable with window size {self.window_size}"
@@ -75,25 +75,30 @@ class LocalAttention(nn.Module):
         num_windows = n // self.window_size

         # Compute buckets
-        b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
+        b_n = (
+            torch.arange(self.num_heads * n)
+            .type_as(q)
+            .reshape(1, self.num_heads, num_windows, self.window_size)
+        )
         bq, bk, bv = map(
-            lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v)
+            lambda t: t.reshape(b, self.num_heads, num_windows, self.window_size, -1),
+            (q, k, v),
         )

-        bk = look_around(bk, backward=self.backward)
-        bv = look_around(bv, backward=self.backward)
-        bq_k = look_around(b_n, backward=self.backward)
+        bk = look_around(bk, backward=self.look_back)
+        bv = look_around(bv, backward=self.look_back)
+        bq_k = look_around(b_n, backward=self.look_back)

         # Compute the attention.
-        energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
+        energy = einsum("b h n i d, b h n j d -> b h n i j", bq, bk) * self.scale
         mask_value = -torch.finfo(energy.dtype).max

         # Causal mask.
-        causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
+        causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
         energy = energy.masked_fill_(causal_mask, mask_value)
         del causal_mask

-        bucket_mask = bq_k[:, :, None, :] == -1
+        bucket_mask = bq_k[:, :, :, None, :] == -1
         energy.masked_fill_(bucket_mask, mask_value)
         del bucket_mask

@@ -101,17 +106,18 @@
             b,
             energy=energy,
             mask=mask,
-            backward=self.backward,
+            backward=self.look_back,
             window_size=self.window_size,
             num_windows=num_windows,
+            num_heads=self.num_heads,
             mask_value=mask_value,
         )

         attn = F.softmax(energy, dim=-1)
         attn = self.dropout(attn)

-        out = einsum("b h i j, b h j d -> b h i d", attn, bv)
-        out = rearrange(out, "b h n d -> b n (h d)")
+        out = einsum("b h n i j, b h n j d -> b h n i d", attn, bv)
+        out = out.reshape(-1, n, d)
         out = self.fc(out)
         return out, attn

@@ -134,12 +140,12 @@ def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
     return t.expand(*expand_shape)


-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
+def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
     """Apply windowing."""
-    n = x.shape[1]
+    n = x.shape[2]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
-    tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
+    tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
     return torch.cat(tensors, dim=dim)


@@ -150,15 +156,17 @@ def apply_input_mask(
     backward: int,
     window_size: int,
     num_windows: int,
+    num_heads: int,
     mask_value: Tensor,
 ) -> Tensor:
     """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
-    mask = mask.reshape(-1, window_size, num_windows)
+    mask = torch.cat([mask] * num_heads)
+    mask = mask.reshape(-1, num_heads, num_windows, window_size)
     mq = mk = mask
     mk = look_around(mk, pad_value=False, backward=backward)
-    mask = mq[:, :, :, None] * mk[:, :, None, :]
-    mask = merge_dims(0, 1, expand_dim(mask, 1, h))
+    mask = mq[:, :, :, :, None] * mk[:, :, :, None, :]
+    mask = merge_dims(1, 2, expand_dim(mask, 2, h))
     energy.masked_fill_(~mask, mask_value)
     del mask
     return energy
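
To make the new tensor layout concrete, below is a small standalone sketch (not part of the commit) that copies the patched look_around helper from the diff and runs it on bucket indices shaped like b_n in the new code, i.e. (1, num_heads, num_windows, window_size). The example sizes (num_heads=2, num_windows=3, window_size=4) and backward=1 are illustrative only.

# Standalone sketch: the patched look_around (dim=3), applied to bucket indices
# that now carry an explicit head dimension, as in the updated LocalAttention.
import torch
import torch.nn.functional as F
from torch import Tensor


def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
    """Apply windowing."""
    n = x.shape[2]
    dims = (len(x.shape) - dim) * (0, 0)
    x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
    tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
    return torch.cat(tensors, dim=dim)


# Example sizes are arbitrary; the layout matches b_n in the patch:
# (1, num_heads, num_windows, window_size).
num_heads, num_windows, window_size = 2, 3, 4
b_n = torch.arange(num_heads * num_windows * window_size).float()
b_n = b_n.reshape(1, num_heads, num_windows, window_size)

# With backward=1 each window is prefixed with the previous window's buckets;
# the first window is padded with pad_value (-1), which the bucket mask in
# LocalAttention later fills with mask_value.
bq_k = look_around(b_n, backward=1)
print(bq_k.shape)     # torch.Size([1, 2, 3, 8])
print(bq_k[0, 0, 0])  # tensor([-1., -1., -1., -1.,  0.,  1.,  2.,  3.])

The window dimension now sits at index 2 (after batch and heads), which is why the helper pads and slices dim 2 and concatenates along dim 3 instead of the old dims 1 and 2.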