path: root/text_recognizer/networks/transformer/layers.py
blob: f740244dd79afc300e21b97598da1b9f802871dc (plain)
"""Transformer attention layer."""
from copy import deepcopy
from typing import Any, Optional, Tuple

import attr
from torch import nn, Tensor

from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.local_attention import LocalAttention
from text_recognizer.networks.transformer.mlp import FeedForward
from text_recognizer.networks.transformer.residual import Residual


@attr.s(eq=False)
class AttentionLayers(nn.Module):
    """Standard transfomer layer."""

    def __attrs_pre_init__(self) -> None:
        """Pre init constructor."""
        super().__init__()

    depth: int = attr.ib()
    self_attn: Attention = attr.ib()
    norm: nn.Module = attr.ib()
    ff: FeedForward = attr.ib()
    cross_attn: Optional[Attention] = attr.ib(default=None)
    local_self_attn: Optional[LocalAttention] = attr.ib(default=None)
    pre_norm: bool = attr.ib(default=True)
    local_depth: Optional[int] = attr.ib(default=None)
    has_pos_emb: bool = attr.ib(default=False)
    layer_types: Tuple[str, ...] = attr.ib(init=False)
    layers: nn.ModuleList = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Post init configuration."""
        if self.local_self_attn is not None and self.local_depth is None:
            raise ValueError("Local depth has to be specified")
        # Repeat the per-layer block pattern ("a", "c", "f") or ("a", "f") depth times.
        self.layer_types = self._get_layer_types() * self.depth
        self.layers = self._build_network()

    def _get_layer_types(self) -> Tuple[str, ...]:
        """Get layer specification."""
        if self.cross_attn is not None:
            return "a", "c", "f"
        return "a", "f"

    def _self_attn_block(self, i: int) -> nn.Module:
        """Returns local attention for the first local_depth layers, otherwise global self-attention."""
        if self.local_depth is not None and i < self.local_depth:
            return deepcopy(self.local_self_attn)
        return deepcopy(self.self_attn)

    def _delete(self) -> None:
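        """Deletes the prototype modules once they have been deep-copied into the layer stack."""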
        del self.self_attn
        del self.local_self_attn
        del self.ff
        del self.norm
        del self.cross_attn

    def _build_network(self) -> nn.ModuleList:
        """Configures transformer network."""
        layers = nn.ModuleList([])
        self_attn_depth = 0
        for layer_type in self.layer_types:
            if layer_type == "a":
                layer = self._self_attn_block(self_attn_depth)
                self_attn_depth += 1
            elif layer_type == "c":
                layer = deepcopy(self.cross_attn)
            elif layer_type == "f":
                layer = deepcopy(self.ff)
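            # Wrap each sub-layer with its own copy of the norm and a residual connection.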
            layers.append(nn.ModuleList([deepcopy(self.norm), layer, Residual()]))
        self._delete()
        return layers

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward pass."""
        for i, (layer_type, (norm, block, residual_fn)) in enumerate(
            zip(self.layer_types, self.layers)
        ):
            is_last = i == len(self.layers) - 1
            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == "a":
                out = block(x=x, mask=mask)
            elif layer_type == "c":
                out = block(x, context=context, mask=mask, context_mask=context_mask)
            elif layer_type == "f":
                out = block(x)

            x = residual_fn(out, residual)

            if not self.pre_norm and not is_last:
                x = norm(x)

        return x


class Decoder(AttentionLayers):
    """Decoder module."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
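

# Usage sketch (illustrative only, not part of the original module): wiring up a
# Decoder from prototype sub-modules, which AttentionLayers deep-copies once per
# layer. The constructor arguments given to Attention and FeedForward below are
# assumptions made for illustration; the real signatures are defined in
# attention.py, local_attention.py, and mlp.py.
#
#     hidden_dim = 256
#     decoder = Decoder(
#         depth=4,
#         self_attn=Attention(dim=hidden_dim, num_heads=8),   # hypothetical arguments
#         cross_attn=Attention(dim=hidden_dim, num_heads=8),  # hypothetical arguments
#         ff=FeedForward(dim=hidden_dim),                     # hypothetical arguments
#         norm=nn.LayerNorm(hidden_dim),
#         pre_norm=True,
#     )
#     out = decoder(x=target_embeddings, context=encoder_output, mask=target_mask)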