diff options
Diffstat (limited to 'text_recognizer/networks/coat')
-rw-r--r-- | text_recognizer/networks/coat/__init__.py | 0 | ||||
-rw-r--r-- | text_recognizer/networks/coat/factor_attention.py | 9 | ||||
-rw-r--r-- | text_recognizer/networks/coat/patch_embedding.py | 38 | ||||
-rw-r--r-- | text_recognizer/networks/coat/positional_encodings.py | 76 |
4 files changed, 0 insertions, 123 deletions
diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/text_recognizer/networks/coat/__init__.py +++ /dev/null diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py deleted file mode 100644 index f91c5ef..0000000 --- a/text_recognizer/networks/coat/factor_attention.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Factorized attention with convolutional relative positional encodings.""" -from torch import nn - - -class FactorAttention(nn.Module): - """Factorized attention with relative positional encodings.""" - def __init__(self, dim: int, num_heads: int) -> None: - pass - diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py deleted file mode 100644 index 3b7b76a..0000000 --- a/text_recognizer/networks/coat/patch_embedding.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Patch embedding for images and feature maps.""" -from typing import Sequence, Tuple - -from einops import rearrange -from loguru import logger -from torch import nn -from torch import Tensor - - -class PatchEmbedding(nn.Module): - """Patch embedding of images.""" - - def __init__( - self, - image_shape: Sequence[int], - patch_size: int = 16, - in_channels: int = 1, - embedding_dim: int = 512, - ) -> None: - if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0: - logger.error( - f"Image shape {image_shape} not divisable by patch size {patch_size}" - ) - - self.patch_size = patch_size - self.embedding = nn.Conv2d( - in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size - ) - self.norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]: - """Embeds image or feature maps with patch embedding.""" - _, _, h, w = x.shape - h_out, w_out = h // self.patch_size, w // self.patch_size - x = self.embedding(x) - x = rearrange(x, "b c h w -> b (h w) c") - x = self.norm(x) - return x, (h_out, w_out) diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py deleted file mode 100644 index 925db04..0000000 --- a/text_recognizer/networks/coat/positional_encodings.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Positional encodings for input sequence to transformer.""" -from typing import Dict, Union, Tuple - -from einops import rearrange -from loguru import logger -import torch -from torch import nn -from torch import Tensor - - -class RelativeEncoding(nn.Module): - """Relative positional encoding.""" - def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None: - super().__init__() - self.windows = {windows: heads} if isinstance(windows, int) else windows - self.heads = list(self.windows.values()) - self.channel_heads = [head * channels for head in self.heads] - self.convs = nn.ModuleList([ - nn.Conv2d(in_channels=head * channels, - out_channels=head * channels, - kernel_shape=window, - padding=window // 2, - dilation=1, - groups=head * channels, - ) for window, head in self.windows.items()]) - - def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor: - """Applies relative positional encoding.""" - b, heads, hw, c = q.shape - h, w = shape - if hw != h * w: - logger.exception(f"Query width {hw} neq to height x width {h * w}") - raise ValueError - - v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w) - v = torch.split(v, self.channel_heads, dim=1) - v = [conv(x) for conv, x in zip(self.convs, v)] - v = torch.cat(v, dim=1) - v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads) - - encoding = q * v - zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device) - encoding = torch.cat((zeros, encoding), dim=2) - return encoding - - -class PositionalEncoding(nn.Module): - """Convolutional positional encoding.""" - def __init__(self, dim: int, k: int = 3) -> None: - super().__init__() - self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim) - - def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: - """Applies convolutional encoding.""" - _, hw, _ = x.shape - h, w = shape - - if hw != h * w: - logger.exception(f"Query width {hw} neq to height x width {h * w}") - raise ValueError - - # Depthwise convolution. - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.encode(x) + x - x = rearrange(x, "b c h w -> b (h w) c") - return x - - - - - - - - - - |