1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
|
"""Generates the attention layer architecture."""
from functools import partial
from typing import Dict, Optional, Type
from click.types import Tuple
import torch
from torch import nn, Tensor
from .attention import Attention
from .mlp import FeedForward
from .residual import Residual
from .rotary_embedding import RotaryEmbedding
class AttentionLayers(nn.Module):
def __init__(
self,
dim: int,
depth: int,
num_heads: int,
ff_kwargs: Dict,
attn_kwargs: Dict,
attn_fn: Type[nn.Module] = Attention,
norm_fn: Type[nn.Module] = nn.LayerNorm,
ff_fn: Type[nn.Module] = FeedForward,
residual_fn: Type[nn.Module] = Residual,
rotary_emb: Optional[Type[nn.Module]] = None,
rotary_emb_dim: Optional[int] = None,
causal: bool = False,
cross_attend: bool = False,
pre_norm: bool = True,
) -> None:
super().__init__()
attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
norm_fn = partial(norm_fn, dim=dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
self.layer_types = self._get_layer_types(cross_attend) * depth
self.layers = self._build_network(
causal, attn_fn, norm_fn, ff_fn, residual_fn
)
rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
self.pre_norm = pre_norm
@staticmethod
def _get_layer_types(cross_attend: bool) -> Tuple:
"""Get layer specification."""
if cross_attend:
return "a", "c", "f"
return "a", "f"
def _build_network(
self,
causal: bool,
attn_fn: partial,
norm_fn: partial,
ff_fn: partial,
residual_fn: Type[nn.Module],
) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
for layer_type in self.layer_types:
if layer_type == "a":
layer = attn_fn(causal=causal)
elif layer_type == "c":
layer = attn_fn()
elif layer_type == "f":
layer = ff_fn()
residual_fn = residual_fn()
layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
return layers
def forward(
self,
x: Tensor,
context: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
is_last = i == len(self.layers) - 1
residual = x
if self.pre_norm:
x = norm(x)
if layer_type == "a":
out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb)
elif layer_type == "c":
out, _ = block(x, context=context, mask=mask, context_mask=context_mask)
elif layer_type == "f":
out = block(x)
x = residual_fn(out, residual)
if not self.pre_norm and not is_last:
x = norm(x)
return x
|