From 1732ed564a738a42c1bf6e8127ae810f5658cb06 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 3 Sep 2023 22:54:09 +0200 Subject: Revert "Delete convnext" This reverts commit 7239bce214607c70a7a91358586f265b2f74de7b. --- text_recognizer/network/convnext/__init__.py | 7 +++ text_recognizer/network/convnext/attention.py | 79 ++++++++++++++++++++++++++ text_recognizer/network/convnext/convnext.py | 77 +++++++++++++++++++++++++ text_recognizer/network/convnext/downsample.py | 21 +++++++ text_recognizer/network/convnext/norm.py | 18 ++++++ text_recognizer/network/convnext/residual.py | 16 ++++++ 6 files changed, 218 insertions(+) create mode 100644 text_recognizer/network/convnext/__init__.py create mode 100644 text_recognizer/network/convnext/attention.py create mode 100644 text_recognizer/network/convnext/convnext.py create mode 100644 text_recognizer/network/convnext/downsample.py create mode 100644 text_recognizer/network/convnext/norm.py create mode 100644 text_recognizer/network/convnext/residual.py (limited to 'text_recognizer') diff --git a/text_recognizer/network/convnext/__init__.py b/text_recognizer/network/convnext/__init__.py new file mode 100644 index 0000000..dcff3fc --- /dev/null +++ b/text_recognizer/network/convnext/__init__.py @@ -0,0 +1,7 @@ +"""Convnext module.""" +from text_recognizer.network.convnext.attention import ( + Attention, + FeedForward, + TransformerBlock, +) +from text_recognizer.network.convnext.convnext import ConvNext diff --git a/text_recognizer/network/convnext/attention.py b/text_recognizer/network/convnext/attention.py new file mode 100644 index 0000000..6bc9692 --- /dev/null +++ b/text_recognizer/network/convnext/attention.py @@ -0,0 +1,79 @@ +"""Convolution self attention block.""" + +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, einsum, nn + +from text_recognizer.network.convnext.norm import LayerNorm +from text_recognizer.network.convnext.residual import Residual + + +def l2norm(t: Tensor) -> Tensor: + return F.normalize(t, dim=-1) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, mult: int = 4) -> None: + super().__init__() + inner_dim = int(dim * mult) + self.fn = Residual( + nn.Sequential( + LayerNorm(dim), + nn.Conv2d(dim, inner_dim, 1, bias=False), + nn.GELU(), + LayerNorm(inner_dim), + nn.Conv2d(inner_dim, dim, 1, bias=False), + ) + ) + + def forward(self, x: Tensor) -> Tensor: + return self.fn(x) + + +class Attention(nn.Module): + def __init__( + self, dim: int, heads: int = 4, dim_head: int = 64, scale: int = 8 + ) -> None: + super().__init__() + self.scale = scale + self.heads = heads + inner_dim = heads * dim_head + self.norm = LayerNorm(dim) + + self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(inner_dim, dim, 1, bias=False) + + def forward(self, x: Tensor) -> Tensor: + h, w = x.shape[-2:] + + residual = x.clone() + + x = self.norm(x) + + q, k, v = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) ... -> b h (...) c", h=self.heads), + (q, k, v), + ) + + q, k = map(l2norm, (q, k)) + + sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + attn = sim.softmax(dim=-1) + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) + return self.to_out(out) + residual + + +class TransformerBlock(nn.Module): + def __init__(self, attn: Attention, ff: FeedForward) -> None: + super().__init__() + self.attn = attn + self.ff = ff + + def forward(self, x: Tensor) -> Tensor: + x = self.attn(x) + x = self.ff(x) + return x diff --git a/text_recognizer/network/convnext/convnext.py b/text_recognizer/network/convnext/convnext.py new file mode 100644 index 0000000..6acf059 --- /dev/null +++ b/text_recognizer/network/convnext/convnext.py @@ -0,0 +1,77 @@ +"""ConvNext module.""" +from typing import Optional, Sequence + +from torch import Tensor, nn + +from text_recognizer.network.convnext.attention import TransformerBlock +from text_recognizer.network.convnext.downsample import Downsample +from text_recognizer.network.convnext.norm import LayerNorm + + +class ConvNextBlock(nn.Module): + """ConvNext block.""" + + def __init__(self, dim: int, dim_out: int, mult: int) -> None: + super().__init__() + self.ds_conv = nn.Conv2d( + dim, dim, kernel_size=(7, 7), padding="same", groups=dim + ) + self.net = nn.Sequential( + LayerNorm(dim), + nn.Conv2d(dim, dim_out * mult, kernel_size=(3, 3), padding="same"), + nn.GELU(), + nn.Conv2d(dim_out * mult, dim_out, kernel_size=(3, 3), padding="same"), + ) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + h = self.ds_conv(x) + h = self.net(h) + return h + self.res_conv(x) + + +class ConvNext(nn.Module): + def __init__( + self, + dim: int = 16, + dim_mults: Sequence[int] = (2, 4, 8), + depths: Sequence[int] = (3, 3, 6), + downsampling_factors: Sequence[Sequence[int]] = ((2, 2), (2, 2), (2, 2)), + attn: Optional[TransformerBlock] = None, + ) -> None: + super().__init__() + dims = (dim, *map(lambda m: m * dim, dim_mults)) + self.attn = attn if attn is not None else nn.Identity() + self.out_channels = dims[-1] + self.stem = nn.Conv2d(1, dims[0], kernel_size=7, padding="same") + self.layers = nn.ModuleList([]) + + for i in range(len(dims) - 1): + dim_in, dim_out = dims[i], dims[i + 1] + self.layers.append( + nn.ModuleList( + [ + ConvNextBlock(dim_in, dim_in, 2), + nn.ModuleList( + [ConvNextBlock(dim_in, dim_in, 2) for _ in range(depths[i])] + ), + Downsample(dim_in, dim_out, downsampling_factors[i]), + ] + ) + ) + self.norm = LayerNorm(dims[-1]) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + for init_block, blocks, down in self.layers: + x = init_block(x) + for fn in blocks: + x = fn(x) + x = down(x) + x = self.attn(x) + return self.norm(x) diff --git a/text_recognizer/network/convnext/downsample.py b/text_recognizer/network/convnext/downsample.py new file mode 100644 index 0000000..a8a0466 --- /dev/null +++ b/text_recognizer/network/convnext/downsample.py @@ -0,0 +1,21 @@ +"""Convnext downsample module.""" +from typing import Tuple + +from einops.layers.torch import Rearrange +from torch import Tensor, nn + + +class Downsample(nn.Module): + """Downsamples feature maps by patches.""" + + def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None: + super().__init__() + s1, s2 = factors + self.fn = nn.Sequential( + Rearrange("b c (h s1) (w s2) -> b (c s1 s2) h w", s1=s1, s2=s2), + nn.Conv2d(dim * s1 * s2, dim_out, 1), + ) + + def forward(self, x: Tensor) -> Tensor: + """Applies patch function.""" + return self.fn(x) diff --git a/text_recognizer/network/convnext/norm.py b/text_recognizer/network/convnext/norm.py new file mode 100644 index 0000000..3355de9 --- /dev/null +++ b/text_recognizer/network/convnext/norm.py @@ -0,0 +1,18 @@ +"""Layer norm for conv layers.""" +import torch +from torch import Tensor, nn + + +class LayerNorm(nn.Module): + """Layer norm for convolutions.""" + + def __init__(self, dim: int) -> None: + super().__init__() + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x: Tensor) -> Tensor: + """Applies layer norm.""" + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + eps).sqrt() * self.gamma diff --git a/text_recognizer/network/convnext/residual.py b/text_recognizer/network/convnext/residual.py new file mode 100644 index 0000000..dfc2847 --- /dev/null +++ b/text_recognizer/network/convnext/residual.py @@ -0,0 +1,16 @@ +"""Generic residual layer.""" +from typing import Callable + +from torch import Tensor, nn + + +class Residual(nn.Module): + """Residual layer.""" + + def __init__(self, fn: Callable) -> None: + super().__init__() + self.fn = fn + + def forward(self, x: Tensor) -> Tensor: + """Applies residual fn.""" + return self.fn(x) + x -- cgit v1.2.3-70-g09d2