Diffstat (limited to 'text_recognizer/networks/coat/positional_encodings.py')
-rw-r--r--   text_recognizer/networks/coat/positional_encodings.py   76
1 file changed, 76 insertions, 0 deletions
diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py
new file mode 100644
index 0000000..925db04
--- /dev/null
+++ b/text_recognizer/networks/coat/positional_encodings.py
@@ -0,0 +1,76 @@
+"""Positional encodings for input sequence to transformer."""
+from typing import Dict, Union, Tuple
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RelativeEncoding(nn.Module):
+    """Relative positional encoding."""
+
+    def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None:
+        super().__init__()
+        self.windows = {windows: heads} if isinstance(windows, int) else windows
+        self.heads = list(self.windows.values())
+        self.channel_heads = [head * channels for head in self.heads]
+        self.convs = nn.ModuleList([
+            nn.Conv2d(
+                in_channels=head * channels,
+                out_channels=head * channels,
+                kernel_size=window,
+                padding=window // 2,
+                dilation=1,
+                groups=head * channels,
+            ) for window, head in self.windows.items()])
+
+    def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor:
+        """Applies relative positional encoding."""
+        b, heads, hw, c = q.shape
+        h, w = shape
+        if hw != h * w:
+            logger.exception(f"Query width {hw} does not equal height * width {h * w}")
+            raise ValueError(f"Query width {hw} does not equal height * width {h * w}")
+
+        v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w)
+        v = torch.split(v, self.channel_heads, dim=1)
+        v = [conv(x) for conv, x in zip(self.convs, v)]
+        v = torch.cat(v, dim=1)
+        v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads)
+
+        encoding = q * v
+        zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device)
+        encoding = torch.cat((zeros, encoding), dim=2)
+        return encoding
+
+
+class PositionalEncoding(nn.Module):
+    """Convolutional positional encoding."""
+
+    def __init__(self, dim: int, k: int = 3) -> None:
+        super().__init__()
+        self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k // 2, groups=dim)
+
+    def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
+        """Applies convolutional encoding."""
+        _, hw, _ = x.shape
+        h, w = shape
+
+        if hw != h * w:
+            logger.exception(f"Query width {hw} does not equal height * width {h * w}")
+            raise ValueError(f"Query width {hw} does not equal height * width {h * w}")
+
+        # Depthwise convolution.
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        x = self.encode(x) + x
+        x = rearrange(x, "b c h w -> b (h w) c")
+        return x
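For reference, a minimal usage sketch of the two modules added above. This is not part of the commit; the batch size, 8x8 feature map, channel counts, and window configuration are illustrative assumptions, and it assumes the file is importable at the path introduced by this diff.

import torch

from text_recognizer.networks.coat.positional_encodings import (
    PositionalEncoding,
    RelativeEncoding,
)

# Convolutional positional encoding on a flattened 8x8 feature map
# with 64 channels and batch size 2: (B, H*W, C) -> (B, H*W, C).
x = torch.randn(2, 8 * 8, 64)
pos_enc = PositionalEncoding(dim=64, k=3)
out = pos_enc(x, shape=(8, 8))  # -> (2, 64, 64)

# Relative positional encoding: q and v are (B, heads, H*W, C_per_head).
# The dict maps kernel size -> number of heads using that window,
# so the head counts must sum to the total number of heads (8 here).
q = torch.randn(2, 8, 8 * 8, 8)
v = torch.randn(2, 8, 8 * 8, 8)
rel_enc = RelativeEncoding(channels=8, heads=8, windows={3: 2, 5: 3, 7: 3})
enc = rel_enc(q, v, shape=(8, 8))  # -> (2, 8, 1 + 64, 8)
# A zero slot is prepended along the sequence dimension, presumably
# reserved for a class token in the surrounding attention module.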