"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
from einops import rearrange
import numpy as np
import torch
from torch import nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """Implementation of multihead attention."""

    def __init__(
        self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Linear projections for the query, key, and value (no bias).
        self.fc_q = nn.Linear(
            in_features=hidden_dim, out_features=hidden_dim, bias=False
        )
        self.fc_k = nn.Linear(
            in_features=hidden_dim, out_features=hidden_dim, bias=False
        )
        self.fc_v = nn.Linear(
            in_features=hidden_dim, out_features=hidden_dim, bias=False
        )
        # Output projection applied after the heads are concatenated.
        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self._init_weights()
        self.dropout = nn.Dropout(p=dropout_rate)

    def _init_weights(self) -> None:
        """Initializes the projection weights."""
        # The query, key, and value projections are drawn from a normal
        # distribution; the output projection uses Xavier initialization.
        nn.init.normal_(
            self.fc_q.weight,
            mean=0,
            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
        )
        nn.init.normal_(
            self.fc_k.weight,
            mean=0,
            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
        )
        nn.init.normal_(
            self.fc_v.weight,
            mean=0,
            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
        )
        nn.init.xavier_normal_(self.fc_out.weight)

    @staticmethod
    def scaled_dot_product_attention(
        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Calculates the scaled dot product attention.

        Computes softmax(Q K^T / sqrt(d_k)) V for each head and returns both
        the attended values and the attention weights.
        """
        # Compute the energy.
        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
            query.shape[-1]
        )
        # If we have a mask for padding some inputs.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -np.inf)
        # Compute the attention from the energy.
        attention = torch.softmax(energy, dim=3)
        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
        # Concatenate the heads back into the hidden dimension.
        out = rearrange(out, "b head l v -> b l (head v)")
        return out, attention

    def forward(
        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Forward pass for computing the multihead attention."""
        # Project the inputs and split them into heads.
        query = rearrange(
            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
        )
        key = rearrange(
            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
        )
        value = rearrange(
            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
        )
        out, attention = self.scaled_dot_product_attention(query, key, value, mask)
        # Apply the output projection and dropout.
        out = self.fc_out(out)
        out = self.dropout(out)
        return out, attention
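

# A minimal usage sketch, not part of the module itself: the batch, sequence,
# hidden, and head sizes below are illustrative assumptions chosen only to show
# how the module is called and what shapes it returns.
if __name__ == "__main__":
    batch_size, seq_len, hidden_dim, num_heads = 2, 16, 64, 8
    mha = MultiHeadAttention(hidden_dim=hidden_dim, num_heads=num_heads)
    x = torch.randn(batch_size, seq_len, hidden_dim)
    # The mask must broadcast against the (batch, head, query_len, key_len) energy.
    mask = torch.ones(batch_size, 1, seq_len, seq_len)
    out, attention = mha(x, x, x, mask)
    print(out.shape)        # torch.Size([2, 16, 64])
    print(attention.shape)  # torch.Size([2, 8, 16, 16])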