summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/layers.py
blob: b2c703fbdd528d280f305e252eefcb0c403588c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Generates the attention layer architecture."""
from functools import partial
from typing import Any, Dict, Optional, Type

from click.types import Tuple

from torch import nn, Tensor

from .attention import Attention
from .mlp import FeedForward
from .residual import Residual
from .positional_encodings.rotary_embedding import RotaryEmbedding


class AttentionLayers(nn.Module):
    """Stack of normalised attention and feedforward blocks with residuals.

    The stack repeats a fixed pattern of layer types ``depth`` times: self
    attention ("a"), optionally cross attention ("c"), and feedforward ("f").
    Every entry is stored together with its own norm and residual module.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        ff_kwargs: Dict,
        attn_kwargs: Dict,
        attn_fn: Type[nn.Module] = Attention,
        norm_fn: Type[nn.Module] = nn.LayerNorm,
        ff_fn: Type[nn.Module] = FeedForward,
        rotary_emb: Optional[Type[nn.Module]] = None,
        rotary_emb_dim: Optional[int] = None,
        causal: bool = False,
        cross_attend: bool = False,
        pre_norm: bool = True,
    ) -> None:
        """Builds the layer stack.

        Args:
            dim: Passed as ``dim`` to the attention and feedforward
                constructors and as first positional argument to ``norm_fn``.
            depth: Number of times the layer pattern is repeated.
            num_heads: Passed as ``num_heads`` to the attention constructor.
            ff_kwargs: Extra keyword arguments for ``ff_fn``.
            attn_kwargs: Extra keyword arguments for ``attn_fn``.
            attn_fn: Constructor for the attention blocks.
            norm_fn: Constructor for the normalisation layers.
            ff_fn: Constructor for the feedforward blocks.
            rotary_emb: Only its truthiness is used; when truthy a
                ``RotaryEmbedding`` is created.
            rotary_emb_dim: Size given to ``RotaryEmbedding``; raised to at
                least 32 when provided.
            causal: Forwarded as ``causal`` to the self attention blocks only.
            cross_attend: If True, a cross attention block is inserted
                between every self attention and feedforward block.
            pre_norm: If True, the norm is applied before each block,
                otherwise after each residual connection except the last one.
        """
        super().__init__()
        # Bind the shared constructor arguments once, so that each layer can
        # be created with only its layer specific options.
        attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
        norm_fn = partial(norm_fn, dim)
        ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
        self.layer_types = self._get_layer_types(cross_attend) * depth
        self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn)
        if rotary_emb_dim is not None:
            rotary_emb_dim = max(rotary_emb_dim, 32)
        # NOTE(review): `rotary_emb` is annotated as a module type but is only
        # used as a flag here, and `RotaryEmbedding` receives None when
        # `rotary_emb_dim` is not given -- confirm both are intended.
        self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
        self.pre_norm = pre_norm
        self.has_pos_emb = self.rotary_emb is not None

    @staticmethod
    def _get_layer_types(cross_attend: bool) -> tuple:
        """Returns the layer pattern that makes up one level of the stack.

        Args:
            cross_attend: Whether to include a cross attention layer.

        Returns:
            ``("a", "c", "f")`` with cross attention, otherwise ``("a", "f")``.
        """
        if cross_attend:
            return "a", "c", "f"
        return "a", "f"

    def _build_network(
        self,
        causal: bool,
        attn_fn: partial,
        norm_fn: partial,
        ff_fn: partial,
    ) -> nn.ModuleList:
        """Configures transformer network.

        Args:
            causal: Forwarded to the self attention ("a") blocks.
            attn_fn: Partially applied attention constructor.
            norm_fn: Partially applied norm constructor.
            ff_fn: Partially applied feedforward constructor.

        Returns:
            A module list with one ``[norm, block, residual]`` module list per
            entry in ``self.layer_types``.

        Raises:
            ValueError: If ``self.layer_types`` contains an unknown type.
        """
        layers = nn.ModuleList([])
        for layer_type in self.layer_types:
            if layer_type == "a":
                layer = attn_fn(causal=causal)
            elif layer_type == "c":
                # Cross attention blocks are built without the causal option.
                layer = attn_fn()
            elif layer_type == "f":
                layer = ff_fn()
            else:
                # Fail loudly instead of silently reusing the previous layer
                # or hitting an unbound local.
                raise ValueError(f"Unknown layer type: {layer_type}")

            residual_fn = Residual()

            layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
        return layers

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Passes the input through every layer of the stack.

        Args:
            x: Input tensor.
            context: Forwarded as ``context`` to the cross attention blocks.
            mask: Forwarded as ``mask`` to the self and cross attention blocks.
            context_mask: Forwarded as ``context_mask`` to the cross attention
                blocks.

        Returns:
            The output of the last residual connection.

        Raises:
            ValueError: If ``self.layer_types`` contains an unknown type.
        """
        # Computed once from the stack input and shared by all self attention
        # blocks.
        rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None

        for i, (layer_type, (norm, block, residual_fn)) in enumerate(
            zip(self.layer_types, self.layers)
        ):
            is_last = i == len(self.layers) - 1

            # The skip connection always carries the un-normalised input.
            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == "a":
                # Attention blocks return a pair; only the first item is used.
                out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb)
            elif layer_type == "c":
                out, _ = block(x, context=context, mask=mask, context_mask=context_mask)
            elif layer_type == "f":
                out = block(x)
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")

            x = residual_fn(out, residual)

            # Post-norm variant: the final layer output is left un-normalised.
            if not self.pre_norm and not is_last:
                x = norm(x)

        return x


class Encoder(AttentionLayers):
    """Attention layer stack that is always built with ``causal=False``."""

    def __init__(self, **kwargs: Any) -> None:
        """Builds the stack, forwarding ``kwargs`` to ``AttentionLayers``.

        Args:
            **kwargs: Keyword arguments for ``AttentionLayers``, except
                ``causal`` which is fixed to False.

        Raises:
            AssertionError: If ``causal`` is passed in ``kwargs``.
        """
        # Raised explicitly rather than with an `assert` statement, so the
        # check is not stripped when Python runs with -O. The exception type
        # and message are kept for backward compatibility.
        if "causal" in kwargs:
            raise AssertionError("Cannot set causality on encoder")
        super().__init__(causal=False, **kwargs)


class Decoder(AttentionLayers):
    """Attention layer stack that is always built with ``causal=True``."""

    def __init__(self, **kwargs: Any) -> None:
        """Builds the stack, forwarding ``kwargs`` to ``AttentionLayers``.

        Args:
            **kwargs: Keyword arguments for ``AttentionLayers``, except
                ``causal`` which is fixed to True.

        Raises:
            AssertionError: If ``causal`` is passed in ``kwargs``.
        """
        # Raised explicitly rather than with an `assert` statement, so the
        # check is not stripped when Python runs with -O. The exception type
        # and message are kept for backward compatibility.
        if "causal" in kwargs:
            raise AssertionError("Cannot set causality on decoder")
        super().__init__(causal=True, **kwargs)