"""Implementes the attention module for the transformer."""
from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, einsum, nn
from text_recognizer.networks.transformer.embeddings.rotary import (
RotaryEmbedding,
)


class Attention(nn.Module):
    """Multi-head attention with a single key/value head shared across all query heads."""

    def __init__(
self,
dim: int,
num_heads: int,
causal: bool = False,
dim_head: int = 64,
dropout_rate: float = 0.0,
) -> None:
super().__init__()
self.dim = dim
        self.scale = dim_head**-0.5  # 1 / sqrt(head dim) for scaled dot-product attention.
self.num_heads = num_heads
self.dim_head = dim_head
self.causal = causal
self.dropout_rate = dropout_rate
        # A single key/value head of size dim_head is shared across all query
        # heads (multi-query attention).
k_dim = dim_head
v_dim = dim_head
out_dim = self.num_heads * self.dim_head
self.to_q = nn.Linear(self.dim, out_dim, bias=False)
self.to_k = nn.Linear(self.dim, k_dim, bias=False)
self.to_v = nn.Linear(self.dim, v_dim, bias=False)
self.dropout = nn.Dropout(p=self.dropout_rate)
        # Output projection back to the model dimension.
        self.fc = nn.Linear(out_dim, self.dim)

    def forward(
self,
x: Tensor,
context: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
rotary_embedding: Optional[RotaryEmbedding] = None,
) -> Tensor:
"""Computes the attention."""
b, device = x.shape[0], x.device
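        # Project the input to queries and split them into attention heads:
        # (b, n, num_heads * dim_head) -> (b, num_heads, n, dim_head).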
q = self.to_q(x)
q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
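        # Keys and values come from the context when cross-attending, otherwise
        # from the input itself; a single head of size dim_head is produced.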
k = self.to_k(context) if context is not None else self.to_k(x)
v = self.to_v(context) if context is not None else self.to_v(x)
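        # Optionally inject positional information via rotary embeddings.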
if rotary_embedding is not None:
q, k, v = map(lambda t: rotary_embedding.rotate(t), (q, k, v))
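        # Scaled dot-product scores: the single key head broadcasts across the
        # query heads, yielding one (i, j) score matrix per head.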
energy = einsum("b h i d, b j d -> b h i j", q, k) * self.scale
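        # Fill masked positions with the most negative finite value so they
        # receive (effectively) zero weight after the softmax.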
mask_value = -torch.finfo(energy.dtype).max
energy = apply_input_mask(b, k, energy, mask, mask_value, device)
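        # For autoregressive decoding, also block attention to future positions.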
if self.causal:
            energy = apply_causal_mask(energy, mask_value, device)
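        # Normalize the scores over the key dimension and apply dropout.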
attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
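        # Aggregate the values with the attention weights, merge the heads, and
        # project back to the model dimension.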
out = einsum("b h i j, b j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.fc(out)
        return out


def apply_input_mask(
    b: int,
    k: Tensor,
    energy: Tensor,
    mask: Optional[Tensor],
    mask_value: float,
    device: torch.device,
) -> Tensor:
"""Applies an input mask."""
if mask is not None:
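        # Combine the query padding mask with an all-ones key mask into a
        # (b, 1, i, j) mask that broadcasts over the attention heads.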
k_mask = torch.ones((b, k.shape[-2]), device=device).bool()
q_mask = rearrange(mask, "b i -> b () i ()")
k_mask = rearrange(k_mask, "b j -> b () () j")
input_mask = q_mask * k_mask
energy = energy.masked_fill_(~input_mask, mask_value)
    return energy


def apply_causal_mask(
    energy: Tensor, mask_value: float, device: torch.device
) -> Tensor:
"""Applies a causal mask to the energy tensor."""
i, j = energy.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False)
energy.masked_fill_(mask, mask_value)
return energy
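

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the batch size, sequence length,
    # and dimensions below are arbitrary values chosen for a quick shape check.
    attention = Attention(dim=128, num_heads=4, causal=True, dim_head=32)
    x = torch.randn(2, 16, 128)  # (batch, sequence length, model dim)
    mask = torch.ones(2, 16, dtype=torch.bool)  # all positions are valid
    out = attention(x, mask=mask)
    print(out.shape)  # Expected: torch.Size([2, 16, 128])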