path: root/text_recognizer/networks/transformer/layers.py
"""Generates the attention layer architecture."""
from functools import partial
from typing import Dict, Optional, Tuple, Type

from torch import nn, Tensor

from .attention import Attention
from .mlp import FeedForward
from .residual import Residual


class AttentionLayers(nn.Module):
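    """Stack of self-attention, optional cross-attention, and feed-forward layers."""
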
    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        ff_kwargs: Dict,
        attn_kwargs: Dict,
        attn_fn: Type[nn.Module] = Attention,
        norm_fn: Type[nn.Module] = nn.LayerNorm,
        ff_fn: Type[nn.Module] = FeedForward,
        residual_fn: Type[nn.Module] = Residual,
        causal: bool = False,
        cross_attend: bool = False,
    ) -> None:
        super().__init__()
        attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
        # nn.LayerNorm takes the feature size as its first positional argument.
        norm_fn = partial(norm_fn, dim)
        ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
        self.layer_types = self._get_layer_types(cross_attend) * depth
        self.layers = self._build_network(
            self.layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn
        )

    @staticmethod
    def _get_layer_types(cross_attend: bool) -> Tuple[str, ...]:
        """Get the layer specification for one block.

        "a" is self-attention, "c" is cross-attention, and "f" is feed-forward.
        """
        if cross_attend:
            return "a", "c", "f"
        return "a", "f"

    @staticmethod
    def _build_network(
        layer_types: Tuple[str, ...],
        causal: bool,
        attn_fn: partial,
        norm_fn: partial,
        ff_fn: partial,
        residual_fn: Type[nn.Module],
    ) -> nn.ModuleList:
        """Configures the transformer layers."""
        layers = nn.ModuleList([])
        for layer_type in layer_types:
            if layer_type == "a":
                # Self-attention, optionally causal.
                layer = attn_fn(causal=causal)
            elif layer_type == "c":
                # Cross-attention over the encoder context.
                layer = attn_fn()
            elif layer_type == "f":
                layer = ff_fn()
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")

            # Instantiate a fresh residual connection per layer; do not rebind
            # the class itself, or the next iteration would call an instance.
            residual = residual_fn()

            layers.append(nn.ModuleList([norm_fn(), layer, residual]))
        return layers

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # Sketch of a pre-norm forward pass. The keyword arguments given to the
        # attention layers and the (out, residual) signature of `Residual` are
        # assumptions inferred from this module's own signature.
        for layer_type, (norm, layer, residual) in zip(self.layer_types, self.layers):
            residual_x = x
            x = norm(x)
            if layer_type == "a":
                x = layer(x, mask=mask)
            elif layer_type == "c":
                x = layer(x, context=context, mask=mask, context_mask=context_mask)
            elif layer_type == "f":
                x = layer(x)
            x = residual(x, residual_x)
        return x
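
# Usage sketch (not part of the original module): assumes `Attention` and
# `FeedForward` need no kwargs beyond the `dim`/`num_heads` supplied above,
# and that the input is a (batch, seq_len, dim) tensor.
#
#     decoder = AttentionLayers(
#         dim=256,
#         depth=4,
#         num_heads=8,
#         ff_kwargs={},
#         attn_kwargs={},
#         causal=True,
#         cross_attend=True,
#     )
#     out = decoder(x, context=encoder_out, mask=mask)  # (batch, seq_len, dim)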