summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/layers.py
blob: 67558ad983e9d29299531a50ee1cb40ee7807037 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Transformer attention layer."""
from copy import deepcopy
from typing import Any, Optional, Tuple, Type

import attr
from torch import nn, Tensor

from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.mlp import FeedForward
from text_recognizer.networks.transformer.residual import Residual


@attr.s(eq=False)
class AttentionLayers(nn.Module):
    """Stack of transformer sub-layers built from template modules.

    The layer pattern (self-attention, optional cross-attention, feed-forward)
    is repeated ``depth`` times. Every sub-layer owns a deep copy of the
    matching template, a deep copy of ``norm`` and a fresh ``Residual``
    module. Once the stack has been built the templates are deleted from the
    instance, so only ``layers`` and ``layer_types`` describe the network.
    """

    def __attrs_pre_init__(self) -> None:
        """Initialise ``nn.Module`` before attrs assigns any field."""
        super().__init__()

    # Number of times the layer pattern is repeated.
    depth: int = attr.ib()
    # Template for the "a" (self-attention) sub-layers.
    self_attn: Attention = attr.ib()
    # Template normalisation module. NOTE: despite the annotation it is used
    # as a module instance: it is deep-copied per sub-layer and called on the
    # activations in ``forward``.
    norm: Type[nn.Module] = attr.ib()
    # Template for the "f" (feed-forward) sub-layers.
    ff: FeedForward = attr.ib()
    # Optional template for the "c" (cross-attention) sub-layers.
    cross_attn: Optional[Attention] = attr.ib(default=None)
    # True: normalise the input of each block; False: normalise the residual
    # sum of every sub-layer except the last one.
    pre_norm: bool = attr.ib(default=True)
    # Not referenced anywhere inside this class.
    has_pos_emb: bool = attr.ib(default=False)
    # Derived in ``__attrs_post_init__``: one type code per sub-layer.
    layer_types: Tuple[str, ...] = attr.ib(init=False)
    # Derived in ``__attrs_post_init__``: one [norm, block, residual] triple
    # per entry of ``layer_types``.
    layers: nn.ModuleList = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Derive the layer specification and build the layer stack."""
        pattern = self._get_layer_types()
        self.layer_types = pattern * self.depth
        self.layers = self._build_network()

    def _get_layer_types(self) -> Tuple[str, ...]:
        """Return the type codes that make up one repetition of the pattern.

        "a" is self-attention, "c" is cross-attention and "f" is feed-forward.
        The "c" entry is only present when a cross-attention template exists.
        """
        return ("a", "c", "f") if self.cross_attn is not None else ("a", "f")

    def _delete(self) -> None:
        """Remove the template modules from the instance."""
        for template_name in ("self_attn", "ff", "norm", "cross_attn"):
            delattr(self, template_name)

    def _build_network(self) -> nn.ModuleList:
        """Assemble the layer stack, then discard the templates.

        Returns:
            A module list holding one ``[norm, block, residual]`` module list
            per entry of ``layer_types``; every entry is an independent copy.
        """
        templates = {"a": self.self_attn, "c": self.cross_attn, "f": self.ff}
        stack = [
            nn.ModuleList(
                [deepcopy(self.norm), deepcopy(templates[kind]), Residual()]
            )
            for kind in self.layer_types
        ]
        # The copies above are the only modules used from here on.
        self._delete()
        return nn.ModuleList(stack)

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        input_mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the input through every sub-layer in order.

        Args:
            x: Input passed to the first sub-layer.
            context: Forwarded to the cross-attention blocks only.
            input_mask: Forwarded to self- and cross-attention blocks.
            context_mask: Forwarded to the cross-attention blocks only.

        Returns:
            The output of the final residual connection (followed by no
            normalisation, in both the pre-norm and post-norm setting).
        """
        last_index = len(self.layers) - 1
        pairs = zip(self.layer_types, self.layers)
        for index, (kind, modules) in enumerate(pairs):
            norm, block, add_residual = modules
            skip = x

            if self.pre_norm:
                x = norm(x)

            if kind == "a":
                out = block(x=x, input_mask=input_mask)
            elif kind == "c":
                out = block(
                    x, context=context, input_mask=input_mask, context_mask=context_mask
                )
            elif kind == "f":
                out = block(x)

            x = add_residual(out, skip)

            # Post-norm: normalise the residual sum, except after the last
            # sub-layer.
            if not (self.pre_norm or index == last_index):
                x = norm(x)

        return x


class Decoder(AttentionLayers):
    """Decoder module.

    Thin subclass of ``AttentionLayers``; it adds no behaviour of its own
    beyond restricting construction to keyword arguments.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Forward every keyword argument unchanged to the parent initializer.

        Args:
            **kwargs: Passed as-is to ``AttentionLayers.__init__``. Positional
                arguments are not accepted by this signature.
        """
        super().__init__(**kwargs)