author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2023-08-25 23:19:14 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2023-08-25 23:19:14 +0200
commit     49ca6ade1a19f7f9c702171537fe4be0dfcda66d (patch)
tree       20062ed1910758481f3d5fff11159706c7b990c6 /text_recognizer/networks/convnext
parent     0421daf6bd97596703f426ba61c401599b538eeb (diff)
Rename and add flash atten
Diffstat (limited to 'text_recognizer/networks/convnext')
-rw-r--r--  text_recognizer/networks/convnext/__init__.py     7
-rw-r--r--  text_recognizer/networks/convnext/attention.py   79
-rw-r--r--  text_recognizer/networks/convnext/convnext.py    77
-rw-r--r--  text_recognizer/networks/convnext/downsample.py  21
-rw-r--r--  text_recognizer/networks/convnext/norm.py        18
-rw-r--r--  text_recognizer/networks/convnext/residual.py    16
6 files changed, 0 insertions, 218 deletions
diff --git a/text_recognizer/networks/convnext/__init__.py b/text_recognizer/networks/convnext/__init__.py
deleted file mode 100644
index faebe6f..0000000
--- a/text_recognizer/networks/convnext/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-"""Convnext module."""
-from text_recognizer.networks.convnext.attention import (
- Attention,
- FeedForward,
- TransformerBlock,
-)
-from text_recognizer.networks.convnext.convnext import ConvNext
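For context, the deleted __init__.py re-exported the public classes at the package root, so callers could import them without reaching into the submodules. An illustrative import as it would have looked before this commit:

# Pre-removal package-level imports (illustrative; the submodule definitions
# appear in the hunks below).
from text_recognizer.networks.convnext import (
    Attention,
    ConvNext,
    FeedForward,
    TransformerBlock,
)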
diff --git a/text_recognizer/networks/convnext/attention.py b/text_recognizer/networks/convnext/attention.py
deleted file mode 100644
index 1334feb..0000000
--- a/text_recognizer/networks/convnext/attention.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""Convolution self attention block."""
-
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Tensor, einsum, nn
-
-from text_recognizer.networks.convnext.norm import LayerNorm
-from text_recognizer.networks.convnext.residual import Residual
-
-
-def l2norm(t: Tensor) -> Tensor:
- return F.normalize(t, dim=-1)
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim: int, mult: int = 4) -> None:
- super().__init__()
- inner_dim = int(dim * mult)
- self.fn = Residual(
- nn.Sequential(
- LayerNorm(dim),
- nn.Conv2d(dim, inner_dim, 1, bias=False),
- nn.GELU(),
- LayerNorm(inner_dim),
- nn.Conv2d(inner_dim, dim, 1, bias=False),
- )
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.fn(x)
-
-
-class Attention(nn.Module):
- def __init__(
- self, dim: int, heads: int = 4, dim_head: int = 64, scale: int = 8
- ) -> None:
- super().__init__()
- self.scale = scale
- self.heads = heads
- inner_dim = heads * dim_head
- self.norm = LayerNorm(dim)
-
- self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
- self.to_out = nn.Conv2d(inner_dim, dim, 1, bias=False)
-
- def forward(self, x: Tensor) -> Tensor:
- h, w = x.shape[-2:]
-
- residual = x.clone()
-
- x = self.norm(x)
-
- q, k, v = self.to_qkv(x).chunk(3, dim=1)
- q, k, v = map(
- lambda t: rearrange(t, "b (h c) ... -> b h (...) c", h=self.heads),
- (q, k, v),
- )
-
- q, k = map(l2norm, (q, k))
-
- sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
- attn = sim.softmax(dim=-1)
-
- out = einsum("b h i j, b h j d -> b h i d", attn, v)
-
- out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
- return self.to_out(out) + residual
-
-
-class TransformerBlock(nn.Module):
- def __init__(self, attn: Attention, ff: FeedForward) -> None:
- super().__init__()
- self.attn = attn
- self.ff = ff
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.attn(x)
- x = self.ff(x)
- return x
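A minimal usage sketch of the deleted attention classes on a 2D feature map, assuming the pre-removal import path. Note that Attention l2-normalizes q and k before the similarity, i.e. cosine-similarity attention with a fixed scale, and both sub-blocks preserve the input shape:

import torch

from text_recognizer.networks.convnext.attention import (
    Attention,
    FeedForward,
    TransformerBlock,
)

dim = 128
block = TransformerBlock(
    attn=Attention(dim=dim, heads=4, dim_head=64),  # inner_dim = 4 * 64 = 256
    ff=FeedForward(dim=dim, mult=4),
)

x = torch.randn(1, dim, 7, 32)  # (batch, channels, height, width)
out = block(x)                  # attention then feed-forward, shape unchanged
assert out.shape == x.shape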
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
deleted file mode 100644
index 9419a15..0000000
--- a/text_recognizer/networks/convnext/convnext.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""ConvNext module."""
-from typing import Optional, Sequence
-
-from torch import Tensor, nn
-
-from text_recognizer.networks.convnext.attention import TransformerBlock
-from text_recognizer.networks.convnext.downsample import Downsample
-from text_recognizer.networks.convnext.norm import LayerNorm
-
-
-class ConvNextBlock(nn.Module):
- """ConvNext block."""
-
- def __init__(self, dim: int, dim_out: int, mult: int) -> None:
- super().__init__()
- self.ds_conv = nn.Conv2d(
- dim, dim, kernel_size=(7, 7), padding="same", groups=dim
- )
- self.net = nn.Sequential(
- LayerNorm(dim),
- nn.Conv2d(dim, dim_out * mult, kernel_size=(3, 3), padding="same"),
- nn.GELU(),
- nn.Conv2d(dim_out * mult, dim_out, kernel_size=(3, 3), padding="same"),
- )
- self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
- def forward(self, x: Tensor) -> Tensor:
- h = self.ds_conv(x)
- h = self.net(h)
- return h + self.res_conv(x)
-
-
-class ConvNext(nn.Module):
- def __init__(
- self,
- dim: int = 16,
- dim_mults: Sequence[int] = (2, 4, 8),
- depths: Sequence[int] = (3, 3, 6),
- downsampling_factors: Sequence[Sequence[int]] = ((2, 2), (2, 2), (2, 2)),
- attn: Optional[TransformerBlock] = None,
- ) -> None:
- super().__init__()
- dims = (dim, *map(lambda m: m * dim, dim_mults))
- self.attn = attn if attn is not None else nn.Identity()
- self.out_channels = dims[-1]
- self.stem = nn.Conv2d(1, dims[0], kernel_size=7, padding="same")
- self.layers = nn.ModuleList([])
-
- for i in range(len(dims) - 1):
- dim_in, dim_out = dims[i], dims[i + 1]
- self.layers.append(
- nn.ModuleList(
- [
- ConvNextBlock(dim_in, dim_in, 2),
- nn.ModuleList(
- [ConvNextBlock(dim_in, dim_in, 2) for _ in range(depths[i])]
- ),
- Downsample(dim_in, dim_out, downsampling_factors[i]),
- ]
- )
- )
- self.norm = LayerNorm(dims[-1])
-
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv2d, nn.Linear)):
- nn.init.trunc_normal_(m.weight, std=0.02)
- nn.init.constant_(m.bias, 0)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.stem(x)
- for init_block, blocks, down in self.layers:
- x = init_block(x)
- for fn in blocks:
- x = fn(x)
- x = down(x)
- x = self.attn(x)
- return self.norm(x)
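To illustrate the deleted encoder end to end, a sketch with the optional TransformerBlock applied after the last stage (import paths as they were before this commit; the input size is illustrative and must be divisible by the combined downsampling factor, here 2 * 2 * 2 = 8):

import torch

from text_recognizer.networks.convnext.attention import (
    Attention,
    FeedForward,
    TransformerBlock,
)
from text_recognizer.networks.convnext.convnext import ConvNext

encoder = ConvNext(
    dim=16,
    dim_mults=(2, 4, 8),                        # channel dims: 16 -> 32 -> 64 -> 128
    depths=(3, 3, 6),
    downsampling_factors=((2, 2), (2, 2), (2, 2)),
    attn=TransformerBlock(Attention(dim=128), FeedForward(dim=128)),
)

x = torch.randn(1, 1, 56, 448)                  # grayscale line image
features = encoder(x)
assert features.shape == (1, encoder.out_channels, 56 // 8, 448 // 8)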
diff --git a/text_recognizer/networks/convnext/downsample.py b/text_recognizer/networks/convnext/downsample.py
deleted file mode 100644
index a8a0466..0000000
--- a/text_recognizer/networks/convnext/downsample.py
+++ /dev/null
@@ -1,21 +0,0 @@
-"""Convnext downsample module."""
-from typing import Tuple
-
-from einops.layers.torch import Rearrange
-from torch import Tensor, nn
-
-
-class Downsample(nn.Module):
- """Downsamples feature maps by patches."""
-
- def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
- super().__init__()
- s1, s2 = factors
- self.fn = nn.Sequential(
- Rearrange("b c (h s1) (w s2) -> b (c s1 s2) h w", s1=s1, s2=s2),
- nn.Conv2d(dim * s1 * s2, dim_out, 1),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies patch function."""
- return self.fn(x)
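What the deleted Downsample layer does, sketched under the pre-removal import path: it folds each s1 x s2 spatial patch into the channel dimension (pixel-unshuffle style) and then mixes channels with a 1x1 convolution, so spatial resolution drops by the given factors:

import torch

from text_recognizer.networks.convnext.downsample import Downsample

down = Downsample(dim=32, dim_out=64, factors=(2, 2))

x = torch.randn(1, 32, 28, 224)
y = down(x)                       # 2x2 patches -> channels, then 1x1 conv to dim_out
assert y.shape == (1, 64, 14, 112)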
diff --git a/text_recognizer/networks/convnext/norm.py b/text_recognizer/networks/convnext/norm.py
deleted file mode 100644
index 3355de9..0000000
--- a/text_recognizer/networks/convnext/norm.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Layer norm for conv layers."""
-import torch
-from torch import Tensor, nn
-
-
-class LayerNorm(nn.Module):
- """Layer norm for convolutions."""
-
- def __init__(self, dim: int) -> None:
- super().__init__()
- self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies layer norm."""
- eps = 1e-5 if x.dtype == torch.float32 else 1e-3
- var = torch.var(x, dim=1, unbiased=False, keepdim=True)
- mean = torch.mean(x, dim=1, keepdim=True)
- return (x - mean) / (var + eps).sqrt() * self.gamma
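The deleted LayerNorm normalizes over the channel dimension at every spatial position, with a learnable per-channel scale and no bias. A quick equivalence check against torch.nn.LayerNorm applied in channels-last layout (import path as before this commit):

import torch
from torch import nn

from text_recognizer.networks.convnext.norm import LayerNorm

x = torch.randn(2, 16, 8, 8)
ln = LayerNorm(16)

# Reference: normalize the channel dim in channels-last layout without affine
# terms, then apply the same learnable per-channel scale.
ref = nn.LayerNorm(16, elementwise_affine=False, eps=1e-5)
expected = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * ln.gamma

torch.testing.assert_close(ln(x), expected, rtol=1e-4, atol=1e-5)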
diff --git a/text_recognizer/networks/convnext/residual.py b/text_recognizer/networks/convnext/residual.py
deleted file mode 100644
index dfc2847..0000000
--- a/text_recognizer/networks/convnext/residual.py
+++ /dev/null
@@ -1,16 +0,0 @@
-"""Generic residual layer."""
-from typing import Callable
-
-from torch import Tensor, nn
-
-
-class Residual(nn.Module):
- """Residual layer."""
-
- def __init__(self, fn: Callable) -> None:
- super().__init__()
- self.fn = fn
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies residual fn."""
- return self.fn(x) + x
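Finally, the deleted Residual wrapper simply adds a skip connection around any shape-preserving callable. A small sketch, again assuming the pre-removal import path:

import torch
from torch import nn

from text_recognizer.networks.convnext.residual import Residual

block = Residual(nn.Conv2d(16, 16, kernel_size=3, padding=1))

x = torch.randn(1, 16, 8, 8)
out = block(x)                    # conv(x) + x; the wrapped fn must preserve shape
assert out.shape == x.shape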