summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/layers.py
blob: 2b8427dadd61accf7c9498ab92275c9742e51fec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Transformer attention layer."""
from typing import Any, Dict, Optional, Tuple

import attr
from torch import nn, Tensor

from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
from text_recognizer.networks.transformer.residual import Residual
from text_recognizer.networks.util import load_partial_fn


@attr.s(eq=False)
class AttentionLayers(nn.Module):
    """Stack of transformer layers built from configurable factories.

    Each repetition of the stack holds a self-attention block ("a"), an optional
    cross-attention block ("c") and a feed-forward block ("f"). Every block is
    stored together with its own norm module and residual module.
    """

    def __attrs_pre_init__(self) -> None:
        """Initialise ``nn.Module`` before attrs assigns any attribute."""
        super().__init__()

    # Passed as `dim` to the attention and feed-forward factories and as
    # `normalized_shape` to the norm factory.
    dim: int = attr.ib()
    # Number of times the layer-type pattern is repeated.
    depth: int = attr.ib()
    # Passed as `num_heads` to the attention factory.
    num_heads: int = attr.ib()
    # Identifier handed to `load_partial_fn` to obtain the attention factory,
    # plus the extra keyword arguments bound to it.
    attn_fn: str = attr.ib()
    attn_kwargs: Dict = attr.ib()
    # Identifier handed to `load_partial_fn` to obtain the norm factory.
    norm_fn: str = attr.ib()
    # Identifier handed to `load_partial_fn` to obtain the feed-forward factory,
    # plus the extra keyword arguments bound to it.
    ff_fn: str = attr.ib()
    ff_kwargs: Dict = attr.ib()
    # Optional callable applied to the input in `forward`; its result is passed
    # to the self-attention blocks as `rotary_pos_emb`.
    rotary_emb: Optional[RotaryEmbedding] = attr.ib()
    # Forwarded as `causal` to the self-attention blocks only.
    causal: bool = attr.ib(default=False)
    # When true, a cross-attention block follows every self-attention block.
    cross_attend: bool = attr.ib(default=False)
    # True: normalise the input of each block. False: normalise the output of
    # each block's residual, except for the last block.
    pre_norm: bool = attr.ib(default=True)
    # Populated in `__attrs_post_init__`.
    layer_types: Tuple[str, ...] = attr.ib(init=False)
    layers: nn.ModuleList = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Derive the layer pattern and instantiate the layers."""
        pattern = self._get_layer_types()
        self.layer_types = pattern * self.depth
        self.layers = self._build_network()

    def _get_layer_types(self) -> Tuple:
        """Return the layer-type codes that make up one repetition."""
        return ("a", "c", "f") if self.cross_attend else ("a", "f")

    def _build_network(self) -> nn.ModuleList:
        """Instantiate one (norm, block, residual) triple per layer type.

        Returns:
            A module list with one inner module list per entry of
            ``self.layer_types``, ordered as norm, block, residual.
        """
        make_attn = load_partial_fn(
            self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
        )
        make_norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
        make_ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)

        stack = nn.ModuleList()
        for kind in self.layer_types:
            if kind == "f":
                block = make_ff()
            elif kind == "c":
                # Cross-attention is built without the causal flag.
                block = make_attn()
            elif kind == "a":
                block = make_attn(causal=self.causal)
            skip = Residual()
            stack.append(nn.ModuleList([make_norm(), block, skip]))
        return stack

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the input through every layer in order.

        Args:
            x: Input tensor.
            context: Passed to the cross-attention blocks as ``context``.
            mask: Passed to the self-attention and cross-attention blocks.
            context_mask: Passed to the cross-attention blocks.

        Returns:
            The tensor produced after the final layer.
        """
        rotary_pos_emb = None
        if self.rotary_emb is not None:
            rotary_pos_emb = self.rotary_emb(x)

        last_index = len(self.layers) - 1
        for index, (kind, triple) in enumerate(zip(self.layer_types, self.layers)):
            norm, block, residual_fn = triple
            shortcut = x

            if self.pre_norm:
                x = norm(x)

            # Attention blocks return a pair; only the first element is used.
            if kind == "a":
                out, _ = block(x=x, mask=mask, rotary_pos_emb=rotary_pos_emb)
            elif kind == "c":
                out, _ = block(x, context=context, mask=mask, context_mask=context_mask)
            elif kind == "f":
                out = block(x)

            x = residual_fn(out, shortcut)

            # Post-norm mode normalises every layer output except the last one.
            if not (self.pre_norm or index == last_index):
                x = norm(x)

        return x


@attr.s(auto_attribs=True, eq=False)
class Encoder(AttentionLayers):
    """Attention stack whose ``causal`` flag is fixed to ``False``."""

    # Redeclared with init=False so it is not a constructor argument and
    # always takes its default value.
    causal: bool = attr.ib(default=False, init=False)


class Decoder(AttentionLayers):
    """Attention stack that always uses ``causal=True``.

    Takes the same keyword arguments as ``AttentionLayers`` except ``causal``,
    which callers are not allowed to pass.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Build the stack with ``causal=True``.

        Args:
            **kwargs: Keyword arguments forwarded to
                ``AttentionLayers.__init__``; must not contain ``causal``.

        Raises:
            AssertionError: If ``causal`` is supplied by the caller.
        """
        # Raised explicitly instead of using `assert` so the check is not
        # stripped under `python -O`; exception type and message are unchanged.
        if "causal" in kwargs:
            raise AssertionError("Cannot set causality on decoder")
        super().__init__(causal=True, **kwargs)