summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/embeddings/axial.py
blob: 25d8f60800ec44648e1c137e9ccf005024100d6b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Axial attention for multi-dimensional data.

Stolen from:
    https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
"""
from functools import reduce
from operator import mul

import torch
from torch import nn


class AxialPositionalEmbedding(nn.Module):
    """Learned positional embedding factorised over the axes of a grid.

    One parameter tensor is created per axis of ``axial_shape``. In
    ``forward`` each tensor is broadcast over the whole grid, the grid is
    flattened into ``max_seq_len`` positions, and the per-axis embeddings are
    summed (when ``axial_dims`` is None) or concatenated along the feature
    dimension (when ``axial_dims`` is given).

    Args:
        dim: Feature size of the returned embedding.
        axial_shape: Size of every grid axis; their product is the maximum
            sequence length.
        axial_dims: Optional per-axis feature sizes. Needs one entry per axis
            and the entries must sum to ``dim``. When omitted, every axis
            uses ``dim`` features and the axes are summed.
    """

    def __init__(self, dim, axial_shape, axial_dims=None):
        super().__init__()

        self.dim = dim
        self.shape = axial_shape
        self.max_seq_len = reduce(mul, axial_shape, 1)

        # Without explicit per-axis sizes every axis gets the full ``dim``
        # and the axes are added together instead of concatenated.
        self.summed = axial_dims is None
        if self.summed:
            axial_dims = (dim,) * len(axial_shape)

        num_axes = len(self.shape)
        assert num_axes == len(
            axial_dims
        ), "number of axial dimensions must equal the number of dimensions in the shape"
        assert (
            self.summed or sum(axial_dims) == dim
        ), f"axial dimensions must sum up to the target dimension {dim}"

        # Each appended parameter is attached to this module as ``weights_<i>``.
        self.weights = ParameterList(self, "weights", len(axial_shape))

        for axis, (axis_len, axis_dim) in enumerate(zip(self.shape, axial_dims)):
            # Size 1 on every grid axis except this one, so the parameter
            # broadcasts across the remaining axes.
            grid_dims = tuple(axis_len if i == axis else 1 for i in range(num_axes))
            param_shape = (1,) + grid_dims + (axis_dim,)
            self.weights.append(nn.Parameter(torch.zeros(param_shape).normal_(0, 1)))

    def forward(self, x):
        """Return the embeddings of the first ``t`` flattened grid positions.

        Args:
            x: Tensor whose shape unpacks as ``(batch, t, _)``. Only its shape
                and, through ``.to(x)``, its dtype and device are used.

        Returns:
            Tensor of shape ``(batch, t, features)`` converted with ``.to(x)``.

        Raises:
            AssertionError: If ``t`` exceeds ``max_seq_len``.
        """
        b, t, _ = x.shape
        assert (
            t <= self.max_seq_len
        ), f"Sequence length ({t}) must be less than the maximum sequence length allowed ({self.max_seq_len})"

        def flatten_over_grid(weight):
            # Broadcast over batch and grid, then collapse the grid axes.
            features = weight.shape[-1]
            full = weight.expand((b, *self.shape, features))
            return full.reshape(b, self.max_seq_len, features)

        embs = [flatten_over_grid(weight) for weight in self.weights.to_list()]

        if self.summed:
            pos_emb = sum(embs)
        else:
            pos_emb = torch.cat(embs, dim=-1)
        return pos_emb[:, :t].to(x)


# a mock parameter list object until below issue is resolved
# https://github.com/pytorch/pytorch/issues/36035
class ParameterList(object):
    """Ordered collection that stores its items as attributes of an owner.

    Every appended item is set on ``kls`` under the name
    ``"<prefix>_<index>"``, with indices counting up from 0.

    Args:
        kls: Object that receives the attributes.
        prefix: Leading part of every attribute name.
        length: Number of slots ``to_list`` reads back.
    """

    def __init__(self, kls, prefix, length):
        # Index under which the next appended item will be stored.
        self.ind = 0
        self.kls = kls
        self.prefix = prefix
        self.length = length

    def _keyname(self, prefix, ind):
        """Build the attribute name for slot ``ind``."""
        return f"{prefix}_{ind}"

    def append(self, x):
        """Store ``x`` on the owner in the next free slot."""
        slot_name = self._keyname(self.prefix, self.ind)
        setattr(self.kls, slot_name, x)
        self.ind = self.ind + 1

    def to_list(self):
        """Return the owner's attributes for slots ``0 .. length - 1`` in order.

        The count comes from ``length``, not from how many items were
        appended.
        """
        owner = self.kls
        items = []
        for slot in range(self.length):
            items.append(getattr(owner, self._keyname(self.prefix, slot)))
        return items


class AxialPositionalEmbeddingImage(nn.Module):
    """Axial positional embedding for inputs unpacked as ``(b, c, h, w)``.

    Wraps ``AxialPositionalEmbedding`` with a two-axis ``axial_shape`` and
    returns the embedding in channels-first layout ``(b, dim, h, w)``.

    NOTE(review): the embedding is looked up for a flattened sequence of
    ``h * w`` tokens. ``AxialPositionalEmbedding.forward`` returns the first
    ``h * w`` entries of its flattened grid, so the entries appear to line up
    with image rows only when ``w`` equals ``axial_shape[1]`` - confirm that
    callers always pass matching sizes.

    Args:
        dim: Feature size of the embedding.
        axial_shape: Two-element grid size.
        axial_dims: Optional per-axis feature sizes, forwarded unchanged.
    """

    def __init__(self, dim, axial_shape, axial_dims=None):
        super().__init__()
        num_axes = len(axial_shape)
        assert num_axes == 2, "Axial shape must have 2 dimensions for images"
        self.dim = dim
        self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims)

    def forward(self, img):
        """Return positional embeddings shaped ``(b, dim, h, w)`` for ``img``."""
        batch, channels, height, width = img.shape
        # Channels-last, then merge the two spatial axes into one token axis.
        channels_last = img.permute(0, 2, 3, 1)
        tokens = channels_last.reshape(batch, height * width, channels)
        flat_emb = self.pos_emb(tokens)
        # Restore the spatial axes and move the features back to axis 1.
        grid_emb = flat_emb.reshape(batch, height, width, self.dim)
        return grid_emb.permute(0, 3, 1, 2)