path: root/text_recognizer/networks/coat/positional_encodings.py
"""Positional encodings for input sequence to transformer."""
from typing import Dict, Union, Tuple

from einops import rearrange
from loguru import logger
import torch
from torch import nn
from torch import Tensor


class RelativeEncoding(nn.Module):
    """Convolutional relative positional encoding.

    Computes a depthwise-convolution encoding of the values and multiplies
    it elementwise with the queries, prepending zeros for the class token.
    """

    def __init__(
        self, channels: int, heads: int, windows: Union[int, Dict[int, int]]
    ) -> None:
        super().__init__()
        # An int window is shared by all heads; a dict maps window size to
        # the number of heads that use it.
        self.windows = {windows: heads} if isinstance(windows, int) else windows
        self.heads = list(self.windows.values())
        self.channel_heads = [head * channels for head in self.heads]
        # One depthwise convolution per window size, acting on the channels
        # of the heads assigned to that window.
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=head * channels,
                    out_channels=head * channels,
                    kernel_size=window,
                    padding=window // 2,
                    dilation=1,
                    groups=head * channels,
                )
                for window, head in self.windows.items()
            ]
        )

    def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor:
        """Applies relative positional encoding."""
        b, heads, hw, c = q.shape
        h, w = shape
        if hw != h * w:
            msg = f"Sequence length {hw} does not equal height x width {h * w}"
            logger.error(msg)
            raise ValueError(msg)

        # Convolve the values per window group.
        v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w)
        v = torch.split(v, self.channel_heads, dim=1)
        v = [conv(x) for conv, x in zip(self.convs, v)]
        v = torch.cat(v, dim=1)
        v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads)

        # Elementwise product with the queries, with zeros prepended for the
        # class token, which receives no relative encoding.
        encoding = q * v
        zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device)
        encoding = torch.cat((zeros, encoding), dim=2)
        return encoding
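
# A minimal shape sketch for RelativeEncoding (hypothetical values, not part
# of the original module): with two heads of 8 channels each sharing a 3x3
# window, inputs laid out as (batch, heads, h * w, channels_per_head) yield
# an encoding with one extra position prepended for the class token.
#
#     encoding = RelativeEncoding(channels=8, heads=2, windows=3)
#     q = torch.randn(1, 2, 4 * 4, 8)
#     v = torch.randn(1, 2, 4 * 4, 8)
#     out = encoding(q, v, shape=(4, 4))  # out.shape == (1, 2, 17, 8)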


class PositionalEncoding(nn.Module):
    """Convolutional positional encoding.

    Encodes positions with a depthwise convolution whose output is added
    residually to the input sequence.
    """

    def __init__(self, dim: int, k: int = 3) -> None:
        super().__init__()
        # Depthwise convolution (groups=dim) that preserves the spatial size.
        self.encode = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=k,
            stride=1,
            padding=k // 2,
            groups=dim,
        )

    def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
        """Applies convolutional encoding."""
        _, hw, _ = x.shape
        h, w = shape

        if hw != h * w:
            msg = f"Sequence length {hw} does not equal height x width {h * w}"
            logger.error(msg)
            raise ValueError(msg)

        # Depthwise convolution.
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.encode(x) + x
        x = rearrange(x, "b c h w -> b (h w) c")
        return x
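

if __name__ == "__main__":
    # Hypothetical smoke test (not part of the original module): check that
    # both encodings preserve the expected shapes on random inputs.
    x = torch.randn(2, 8 * 8, 32)  # (batch, h * w, dim)
    pos = PositionalEncoding(dim=32, k=3)
    assert pos(x, shape=(8, 8)).shape == (2, 8 * 8, 32)

    q = torch.randn(2, 4, 8 * 8, 16)  # (batch, heads, h * w, channels)
    v = torch.randn(2, 4, 8 * 8, 16)
    rel = RelativeEncoding(channels=16, heads=4, windows={3: 2, 5: 2})
    assert rel(q, v, shape=(8, 8)).shape == (2, 4, 1 + 8 * 8, 16)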