1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
|
"""Implements the attention module for the transformer."""
from typing import Optional, Tuple
from einops.layers.torch import Rearrange
import numpy as np
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Cross-attention (attending over a separate ``context``) is not
    implemented yet; passing a context raises ``NotImplementedError``.

    Args:
        dim: Size of the last dimension of the input and output tensors.
        num_heads: Number of attention heads.
        dim_head: Size of each head's query/key/value vectors.
        dropout_rate: Dropout probability applied to the attention weights.
        causal: If True, each position may only attend to itself and to
            earlier positions.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        dim_head: int = 64,
        dropout_rate: float = 0.0,
        causal: bool = False,
    ) -> None:
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        # Scaled dot-product attention divides the logits by sqrt(d_k), where
        # d_k is the per-head key size, not the model width.
        self.scale = dim_head ** -0.5
        self.num_heads = num_heads
        self.causal = causal
        # All heads together span num_heads * dim_head features.
        inner_dim = num_heads * dim_head

        # Attention: a single projection produces queries, keys and values.
        self.qkv_fn = nn.Linear(dim, 3 * inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.attn_fn = F.softmax

        # Output projection back to the model width.
        self.proj = nn.Linear(inner_dim, dim)

    def _split_heads(self, qkv: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Split the fused projection into per-head q, k and v.

        Equivalent to the einops pattern
        ``"b n (qkv h d) -> qkv b h n d"`` with ``qkv=3, h=num_heads``:
        each returned tensor has shape (batch, num_heads, seq_len, dim_head).
        """
        b, n, _ = qkv.shape
        qkv = qkv.reshape(b, n, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        return q, k, v

    @staticmethod
    def _apply_rotary_emb(
        q: Tensor, k: Tensor, rotary_pos_emb: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Apply rotary position embeddings to the leading features of q and k.

        Only the first ``rotary_pos_emb.shape[-1]`` features of each head are
        passed through ``apply_rotary_pos_emb``; the remaining features are
        concatenated back unchanged.
        """
        rot_dim = rotary_pos_emb.shape[-1]
        (q_rot, q_pass), (k_rot, k_pass) = (
            (t[..., :rot_dim], t[..., rot_dim:]) for t in (q, k)
        )
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, rotary_pos_emb)
        q = torch.cat((q_rot, q_pass), dim=-1)
        k = torch.cat((k_rot, k_pass), dim=-1)
        return q, k

    def _cross_attention(self) -> Tensor:
        """Placeholder for cross-attention over a separate context."""
        raise NotImplementedError("Cross-attention is not implemented yet.")

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        rotary_pos_emb: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq_len, dim).
            context: Source sequence for cross-attention. Not supported yet;
                must be None.
            mask: Optional (batch, seq_len) tensor, converted to bool; True
                marks positions that may attend and be attended to.
            context_mask: Mask for ``context``. Unused because cross-attention
                is not supported.
            rotary_pos_emb: Optional rotary embedding applied to q and k.

        Returns:
            A tuple ``(out, attn)``: ``out`` has shape (batch, seq_len, dim) and
            ``attn`` holds the attention weights used to mix the values
            (after dropout), shape (batch, num_heads, seq_len, seq_len).

        Raises:
            NotImplementedError: If ``context`` is given.
        """
        if context is not None:
            # Silently attending over x instead of context would be wrong.
            raise NotImplementedError("Cross-attention is not implemented yet.")

        b, n, _ = x.shape
        device = x.device
        q, k, v = self._split_heads(self.qkv_fn(x))
        if rotary_pos_emb is not None:
            q, k = self._apply_rotary_emb(q, k, rotary_pos_emb)

        # Compute the attention logits: (batch, heads, query, key).
        energy = (q @ k.transpose(-2, -1)) * self.scale
        # Finite "minus infinity": masked logits get zero weight after the
        # softmax without producing NaNs in fully-masked rows.
        mask_value = -torch.finfo(energy.dtype).max

        if mask is not None:
            # Self-attention: keys come from x, so they share the query mask.
            mask = mask.bool()
            input_mask = mask[:, None, :, None] & mask[:, None, None, :]
            energy = energy.masked_fill(~input_mask, mask_value)

        if self.causal:
            i, j = energy.shape[-2:]
            # True strictly above the (right-aligned) diagonal = future keys.
            future = torch.ones(i, j, dtype=torch.bool, device=device).triu(
                j - i + 1
            )
            energy = energy.masked_fill(future, mask_value)

        attn = self.attn_fn(energy, dim=-1)
        attn = self.dropout(attn)

        # Mix values, merge heads back to (batch, seq_len, num_heads * dim_head).
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.proj(out), attn
|