summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/layers.py
blob: 4263f525b57d371998b3622927844fde331fefc3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Transformer attention layer."""
from copy import deepcopy
from typing import Any, Optional, Tuple, Type

from torch import nn, Tensor

from text_recognizer.networks.transformer.attention import Attention
from text_recognizer.networks.transformer.mlp import FeedForward
from text_recognizer.networks.transformer.residual import Residual


class AttentionLayers(nn.Module):
    """Stack of transformer sub-layers: self-attention, optional cross-attention, feed-forward.

    Every sub-layer is stored as ``nn.ModuleList([norm, block, Residual()])``.
    With ``pre_norm`` the norm is applied to the block input; otherwise it is
    applied to the residual output of every sub-layer except the last one.
    """

    def __init__(
        self,
        depth: int,
        self_attn: Attention,
        norm: nn.Module,
        ff: FeedForward,
        cross_attn: Optional[Attention] = None,
        pre_norm: bool = True,
        has_pos_emb: bool = True,
    ) -> None:
        """Build ``depth`` repetitions of the sub-layer pattern.

        Args:
            depth: Number of times the pattern ("a", "f") or ("a", "c", "f") is repeated.
            self_attn: Template self-attention module, deep-copied for every "a" sub-layer.
            norm: Template normalization module *instance* (not a class); it is
                deep-copied for every sub-layer and stored in an ``nn.ModuleList``.
            ff: Template feed-forward module, deep-copied for every "f" sub-layer.
            cross_attn: Optional template cross-attention module. When given, a "c"
                sub-layer follows each self-attention sub-layer.
            pre_norm: If True, normalize before each block; otherwise normalize after
                the residual connection.
            has_pos_emb: Stored as an attribute; not used inside this class.
        """
        super().__init__()
        self.pre_norm = pre_norm
        self.has_pos_emb = has_pos_emb
        # Keep only a flag. Assigning the template module to ``self`` would register it
        # as an extra, unused submodule (its parameters would be trained/saved for nothing).
        # This must be set before ``_get_layer_types`` is called.
        self.has_cross_attn = cross_attn is not None
        self.layer_types = self._get_layer_types() * depth
        self.layers = self._build(self_attn, norm, ff, cross_attn)

    def _get_layer_types(self) -> Tuple[str, ...]:
        """Return the sub-layer pattern for one depth step.

        "a" = self-attention, "c" = cross-attention, "f" = feed-forward.
        """
        if self.has_cross_attn:
            return "a", "c", "f"
        return "a", "f"

    def _build(
        self,
        self_attn: Attention,
        norm: nn.Module,
        ff: FeedForward,
        cross_attn: Optional[Attention],
    ) -> nn.ModuleList:
        """Create one ``[norm, block, residual]`` triple per entry of ``self.layer_types``.

        Each triple gets its own deep copies so no parameters are shared between sub-layers.

        Raises:
            ValueError: If ``self.layer_types`` contains an unknown code.
        """
        templates = {"a": self_attn, "c": cross_attn, "f": ff}
        layers = nn.ModuleList([])
        for layer_type in self.layer_types:
            if layer_type not in templates:
                raise ValueError(f"Unknown layer type: {layer_type}")
            layer = deepcopy(templates[layer_type])
            layers.append(nn.ModuleList([deepcopy(norm), layer, Residual()]))
        return layers

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        input_mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run ``x`` through every sub-layer in order.

        Args:
            x: Input tensor (presumably (batch, seq, dim) — TODO confirm).
            context: Tensor attended to by cross-attention sub-layers; unused otherwise.
            input_mask: Mask forwarded to self- and cross-attention blocks.
            context_mask: Mask for ``context``, forwarded to cross-attention blocks.

        Returns:
            The output of the last sub-layer.
        """
        last_index = len(self.layers) - 1
        for i, (layer_type, (norm, block, residual_fn)) in enumerate(
            zip(self.layer_types, self.layers)
        ):
            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == "a":
                out = block(x=x, input_mask=input_mask)
            elif layer_type == "c":
                out = block(
                    x, context=context, input_mask=input_mask, context_mask=context_mask
                )
            elif layer_type == "f":
                out = block(x)

            x = residual_fn(out, residual)

            # Post-norm variant: normalize after the residual, skipping the final sub-layer.
            if not self.pre_norm and i != last_index:
                x = norm(x)

        return x


class Decoder(AttentionLayers):
    """Decoder module: attention layers that must include cross attention."""

    def __init__(self, **kwargs: Any) -> None:
        """Build the decoder stack.

        Args:
            **kwargs: Forwarded to ``AttentionLayers``; must include a non-None ``cross_attn``.

        Raises:
            ValueError: If ``cross_attn`` is missing or None.
        """
        # Reject up front: without cross attention the stack would silently be built
        # with only self-attention and feed-forward sub-layers.
        if kwargs.get("cross_attn") is None:
            raise ValueError("Decoder requires cross attention.")

        super().__init__(**kwargs)


class Encoder(AttentionLayers):
    """Encoder module: attention layers without cross attention."""

    def __init__(self, **kwargs: Any) -> None:
        """Build the encoder stack.

        Args:
            **kwargs: Forwarded to ``AttentionLayers``; ``cross_attn`` must be absent or None.

        Raises:
            ValueError: If a non-None ``cross_attn`` is supplied.
        """
        # An explicit ``cross_attn=None`` is accepted; only a real module is rejected.
        if kwargs.get("cross_attn") is not None:
            raise ValueError("Encoder does not support cross attention.")

        super().__init__(**kwargs)