summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/embeddings/axial.py
blob: 56f29c5b0caecf8647a3bf65cb9addcf8a5e629c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""Axial attention for multi-dimensional data.

Stolen from:
    https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
"""
from typing import Sequence

import torch
from torch import nn, Tensor


class AxialPositionalEmbedding(nn.Module):
    """Axial positional embedding."""

    def __init__(self, dim: int, shape: Sequence[int], emb_dim_index: int = 1) -> None:
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            shape = [1] * total_dimensions
            shape[emb_dim_index] = dim
            shape[axial_dim_index] = axial_dim
            parameter = nn.Parameter(torch.randn(*shape))
            setattr(self, f"param_{i}", parameter)

    def forward(self, x: Tensor) -> Tensor:
        """Applies axial positional embedding."""
        for i in range(self.num_axials):
            x = x + getattr(self, f"param_{i}")
        return x