1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
"""Transformer attention layer."""
from functools import partial
from typing import Any, Dict, Optional, Tuple
import attr
from torch import nn, Tensor
from text_recognizer.networks.transformer.residual import Residual
from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import (
RotaryEmbedding,
)
from text_recognizer.networks.util import load_partial_fn
# eq=False: attrs would otherwise generate __eq__ and set __hash__ to None.
# nn.Module instances must stay hashable because PyTorch keeps modules in sets
# while traversing the module tree (e.g. named_modules(), parameters()).
@attr.s(eq=False)
class AttentionLayers(nn.Module):
    """Standard transformer layer stack.

    Builds ``depth`` repetitions of the layer pattern returned by
    ``_get_layer_types`` (self-attention, optional cross-attention,
    feedforward). Every entry in ``layers`` is a ``[norm, block, residual]``
    triple that ``forward`` applies in order.
    """

    def __attrs_pre_init__(self) -> None:
        """Initialise ``nn.Module`` before attrs assigns any attribute."""
        super().__init__()

    # --- Constructor arguments -------------------------------------------
    # Passed as ``dim=`` to the attention, norm and feedforward factories.
    dim: int = attr.ib()
    # Number of times the layer pattern is repeated.
    depth: int = attr.ib()
    # Passed as ``num_heads=`` to the attention factory.
    num_heads: int = attr.ib()
    # Identifier handed to ``load_partial_fn`` to obtain the attention factory.
    attn_fn: str = attr.ib()
    # Extra keyword arguments for the attention factory.
    attn_kwargs: Dict = attr.ib()
    # Identifier handed to ``load_partial_fn`` to obtain the norm factory.
    norm_fn: str = attr.ib()
    # Identifier handed to ``load_partial_fn`` to obtain the feedforward factory.
    ff_fn: str = attr.ib()
    # Extra keyword arguments for the feedforward factory.
    ff_kwargs: Dict = attr.ib()
    # Passed as ``causal=`` to every self-attention ("a") block.
    causal: bool = attr.ib(default=False)
    # When True a cross-attention ("c") block follows each self-attention block.
    cross_attend: bool = attr.ib(default=False)
    # True: normalise the block input; False: normalise after the residual.
    pre_norm: bool = attr.ib(default=True)

    # --- Derived state (not constructor arguments) -----------------------
    # Not settable through the constructor; stays None unless assigned later.
    rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
    has_pos_emb: bool = attr.ib(init=False)
    layer_types: Tuple[str, ...] = attr.ib(init=False)
    layers: nn.ModuleList = attr.ib(init=False)
    attn: partial = attr.ib(init=False)
    norm: partial = attr.ib(init=False)
    ff: partial = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Post init configuration: derive layer types and build the layers."""
        self.has_pos_emb = self.rotary_emb is not None
        self.layer_types = self._get_layer_types() * self.depth
        # Keep the factories on the instance: the attrs-generated __repr__
        # reads every declared field and fails if one was never assigned.
        self.attn = load_partial_fn(
            self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
        )
        self.norm = load_partial_fn(self.norm_fn, dim=self.dim)
        self.ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
        self.layers = self._build_network(self.attn, self.norm, self.ff)

    def _get_layer_types(self) -> Tuple:
        """Get layer specification.

        Returns:
            ``("a", "c", "f")`` when cross attention is enabled, otherwise
            ``("a", "f")``; "a" is self-attention, "c" cross-attention and
            "f" feedforward.
        """
        if self.cross_attend:
            return "a", "c", "f"
        return "a", "f"

    def _build_network(
        self, attn: partial, norm: partial, ff: partial,
    ) -> nn.ModuleList:
        """Configures transformer network.

        Args:
            attn: Factory for attention blocks.
            norm: Factory for normalisation layers.
            ff: Factory for feedforward blocks.

        Returns:
            A ``ModuleList`` with one ``[norm, block, residual]`` triple per
            entry in ``self.layer_types``.

        Raises:
            ValueError: If ``self.layer_types`` contains an unknown type.
        """
        layers = nn.ModuleList([])
        for layer_type in self.layer_types:
            if layer_type == "a":
                layer = attn(causal=self.causal)
            elif layer_type == "c":
                # Cross-attention is built without the ``causal`` argument.
                layer = attn()
            elif layer_type == "f":
                layer = ff()
            else:
                raise ValueError(f"Unknown layer type: {layer_type!r}")
            residual_fn = Residual()
            layers.append(nn.ModuleList([norm(), layer, residual_fn]))
        return layers

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply every layer in order.

        Args:
            x: Input tensor.
            context: Passed as ``context=`` to cross-attention blocks.
            mask: Passed as ``mask=`` to self- and cross-attention blocks.
            context_mask: Passed as ``context_mask=`` to cross-attention blocks.

        Returns:
            The tensor produced by the final residual connection.
        """
        rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
        last_index = len(self.layers) - 1
        for i, (layer_type, (norm, block, residual_fn)) in enumerate(
            zip(self.layer_types, self.layers)
        ):
            residual = x

            if self.pre_norm:
                x = norm(x)

            # Attention blocks return a pair; only the first item is used.
            if layer_type == "a":
                out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb)
            elif layer_type == "c":
                out, _ = block(x, context=context, mask=mask, context_mask=context_mask)
            elif layer_type == "f":
                out = block(x)

            x = residual_fn(out, residual)

            # Post-norm mode skips the normalisation after the last layer.
            if not self.pre_norm and i != last_index:
                x = norm(x)

        return x
class Encoder(AttentionLayers):
    """Attention layer stack constructed with ``causal=False``."""

    def __init__(self, **kwargs: Any) -> None:
        """Build the layers with causality disabled.

        Args:
            **kwargs: Forwarded unchanged to the parent constructor; must not
                contain ``causal``.

        Raises:
            AssertionError: If ``causal`` is supplied.
        """
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; exception type and message are unchanged.
        if "causal" in kwargs:
            raise AssertionError("Cannot set causality on encoder")
        super().__init__(causal=False, **kwargs)
class Decoder(AttentionLayers):
    """Attention layer stack constructed with ``causal=True``."""

    def __init__(self, **kwargs: Any) -> None:
        """Build the layers with causality enabled.

        Args:
            **kwargs: Forwarded unchanged to the parent constructor; must not
                contain ``causal``.

        Raises:
            AssertionError: If ``causal`` is supplied.
        """
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; exception type and message are unchanged.
        if "causal" in kwargs:
            raise AssertionError("Cannot set causality on decoder")
        super().__init__(causal=True, **kwargs)
|