From 4e60c836fb710baceba570c28c06437db3ad5c9b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 24 Apr 2021 23:09:20 +0200 Subject: Implementing CoaT transformer, continue tomorrow... --- text_recognizer/networks/coat/__init__.py | 0 text_recognizer/networks/coat/factor_attention.py | 9 +++ text_recognizer/networks/coat/patch_embedding.py | 38 +++++++++++ .../networks/coat/positional_encodings.py | 76 ++++++++++++++++++++++ 4 files changed, 123 insertions(+) create mode 100644 text_recognizer/networks/coat/__init__.py create mode 100644 text_recognizer/networks/coat/factor_attention.py create mode 100644 text_recognizer/networks/coat/patch_embedding.py create mode 100644 text_recognizer/networks/coat/positional_encodings.py (limited to 'text_recognizer') diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py new file mode 100644 index 0000000..f91c5ef --- /dev/null +++ b/text_recognizer/networks/coat/factor_attention.py @@ -0,0 +1,9 @@ +"""Factorized attention with convolutional relative positional encodings.""" +from torch import nn + + +class FactorAttention(nn.Module): + """Factorized attention with relative positional encodings.""" + def __init__(self, dim: int, num_heads: int) -> None: + pass + diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py new file mode 100644 index 0000000..3b7b76a --- /dev/null +++ b/text_recognizer/networks/coat/patch_embedding.py @@ -0,0 +1,38 @@ +"""Patch embedding for images and feature maps.""" +from typing import Sequence, Tuple + +from einops import rearrange +from loguru import logger +from torch import nn +from torch import Tensor + + +class PatchEmbedding(nn.Module): + """Patch embedding of images.""" + + def __init__( + self, + image_shape: Sequence[int], + patch_size: int = 16, + in_channels: int = 1, + embedding_dim: int = 512, + ) -> None: + if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0: + logger.error( + f"Image shape {image_shape} not divisable by patch size {patch_size}" + ) + + self.patch_size = patch_size + self.embedding = nn.Conv2d( + in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]: + """Embeds image or feature maps with patch embedding.""" + _, _, h, w = x.shape + h_out, w_out = h // self.patch_size, w // self.patch_size + x = self.embedding(x) + x = rearrange(x, "b c h w -> b (h w) c") + x = self.norm(x) + return x, (h_out, w_out) diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py new file mode 100644 index 0000000..925db04 --- /dev/null +++ b/text_recognizer/networks/coat/positional_encodings.py @@ -0,0 +1,76 @@ +"""Positional encodings for input sequence to transformer.""" +from typing import Dict, Union, Tuple + +from einops import rearrange +from loguru import logger +import torch +from torch import nn +from torch import Tensor + + +class RelativeEncoding(nn.Module): + """Relative positional encoding.""" + def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None: + super().__init__() + self.windows = {windows: heads} if isinstance(windows, int) else windows + self.heads = list(self.windows.values()) + self.channel_heads = [head * channels for head in self.heads] + self.convs = nn.ModuleList([ + nn.Conv2d(in_channels=head * channels, + out_channels=head * channels, + kernel_shape=window, + padding=window // 2, + dilation=1, + groups=head * channels, + ) for window, head in self.windows.items()]) + + def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor: + """Applies relative positional encoding.""" + b, heads, hw, c = q.shape + h, w = shape + if hw != h * w: + logger.exception(f"Query width {hw} neq to height x width {h * w}") + raise ValueError + + v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w) + v = torch.split(v, self.channel_heads, dim=1) + v = [conv(x) for conv, x in zip(self.convs, v)] + v = torch.cat(v, dim=1) + v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads) + + encoding = q * v + zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device) + encoding = torch.cat((zeros, encoding), dim=2) + return encoding + + +class PositionalEncoding(nn.Module): + """Convolutional positional encoding.""" + def __init__(self, dim: int, k: int = 3) -> None: + super().__init__() + self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim) + + def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: + """Applies convolutional encoding.""" + _, hw, _ = x.shape + h, w = shape + + if hw != h * w: + logger.exception(f"Query width {hw} neq to height x width {h * w}") + raise ValueError + + # Depthwise convolution. + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.encode(x) + x + x = rearrange(x, "b c h w -> b (h w) c") + return x + + + + + + + + + + -- cgit v1.2.3-70-g09d2