diff options
-rw-r--r-- | text_recognizer/network/convnext/__init__.py | 7 | ||||
-rw-r--r-- | text_recognizer/network/convnext/convnext.py | 41 | ||||
-rw-r--r-- | text_recognizer/network/convnext/downsample.py | 9 | ||||
-rw-r--r-- | text_recognizer/network/convnext/residual.py | 16 | ||||
-rw-r--r-- | text_recognizer/network/convnext/transformer.py (renamed from text_recognizer/network/convnext/attention.py) | 35 |
5 files changed, 42 insertions, 66 deletions
diff --git a/text_recognizer/network/convnext/__init__.py b/text_recognizer/network/convnext/__init__.py index dcff3fc..e69de29 100644 --- a/text_recognizer/network/convnext/__init__.py +++ b/text_recognizer/network/convnext/__init__.py @@ -1,7 +0,0 @@ -"""Convnext module.""" -from text_recognizer.network.convnext.attention import ( - Attention, - FeedForward, - TransformerBlock, -) -from text_recognizer.network.convnext.convnext import ConvNext diff --git a/text_recognizer/network/convnext/convnext.py b/text_recognizer/network/convnext/convnext.py index 6acf059..8eea9df 100644 --- a/text_recognizer/network/convnext/convnext.py +++ b/text_recognizer/network/convnext/convnext.py @@ -1,11 +1,27 @@ """ConvNext module.""" from typing import Optional, Sequence +import torch from torch import Tensor, nn -from text_recognizer.network.convnext.attention import TransformerBlock -from text_recognizer.network.convnext.downsample import Downsample -from text_recognizer.network.convnext.norm import LayerNorm +from .transformer import Transformer +from .downsample import Downsample +from .norm import LayerNorm + + +class GRN(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.zeros(dim, 1, 1)) + self.bias = nn.Parameter(torch.zeros(dim, 1, 1)) + + def forward(self, x): + spatial_l2_norm = x.norm(p=2, dim=(2, 3), keepdim=True) + feat_norm = spatial_l2_norm / spatial_l2_norm.mean(dim=-1, keepdim=True).clamp( + min=self.eps + ) + return x * feat_norm * self.gamma + self.bias + x class ConvNextBlock(nn.Module): @@ -13,14 +29,14 @@ class ConvNextBlock(nn.Module): def __init__(self, dim: int, dim_out: int, mult: int) -> None: super().__init__() - self.ds_conv = nn.Conv2d( - dim, dim, kernel_size=(7, 7), padding="same", groups=dim - ) + inner_dim = mult * dim_out + self.ds_conv = nn.Conv2d(dim, dim, kernel_size=7, padding="same", groups=dim) self.net = nn.Sequential( LayerNorm(dim), - nn.Conv2d(dim, dim_out * mult, kernel_size=(3, 3), padding="same"), + nn.Conv2d(dim, inner_dim, kernel_size=3, stride=1, padding="same"), nn.GELU(), - nn.Conv2d(dim_out * mult, dim_out, kernel_size=(3, 3), padding="same"), + GRN(inner_dim), + nn.Conv2d(inner_dim, dim_out, kernel_size=3, stride=1, padding="same"), ) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() @@ -36,8 +52,7 @@ class ConvNext(nn.Module): dim: int = 16, dim_mults: Sequence[int] = (2, 4, 8), depths: Sequence[int] = (3, 3, 6), - downsampling_factors: Sequence[Sequence[int]] = ((2, 2), (2, 2), (2, 2)), - attn: Optional[TransformerBlock] = None, + attn: Optional[Transformer] = None, ) -> None: super().__init__() dims = (dim, *map(lambda m: m * dim, dim_mults)) @@ -51,11 +66,10 @@ class ConvNext(nn.Module): self.layers.append( nn.ModuleList( [ - ConvNextBlock(dim_in, dim_in, 2), nn.ModuleList( [ConvNextBlock(dim_in, dim_in, 2) for _ in range(depths[i])] ), - Downsample(dim_in, dim_out, downsampling_factors[i]), + Downsample(dim_in, dim_out), ] ) ) @@ -68,8 +82,7 @@ class ConvNext(nn.Module): def forward(self, x: Tensor) -> Tensor: x = self.stem(x) - for init_block, blocks, down in self.layers: - x = init_block(x) + for blocks, down in self.layers: for fn in blocks: x = fn(x) x = down(x) diff --git a/text_recognizer/network/convnext/downsample.py b/text_recognizer/network/convnext/downsample.py index a8a0466..dcc14aa 100644 --- a/text_recognizer/network/convnext/downsample.py +++ b/text_recognizer/network/convnext/downsample.py @@ -1,6 +1,4 @@ """Convnext downsample module.""" -from typing import Tuple - from einops.layers.torch import Rearrange from torch import Tensor, nn @@ -8,12 +6,11 @@ from torch import Tensor, nn class Downsample(nn.Module): """Downsamples feature maps by patches.""" - def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None: + def __init__(self, dim: int, dim_out: int) -> None: super().__init__() - s1, s2 = factors self.fn = nn.Sequential( - Rearrange("b c (h s1) (w s2) -> b (c s1 s2) h w", s1=s1, s2=s2), - nn.Conv2d(dim * s1 * s2, dim_out, 1), + Rearrange("b c (h s1) (w s2) -> b (c s1 s2) h w", s1=2, s2=2), + nn.Conv2d(dim * 4, dim_out, 1), ) def forward(self, x: Tensor) -> Tensor: diff --git a/text_recognizer/network/convnext/residual.py b/text_recognizer/network/convnext/residual.py deleted file mode 100644 index dfc2847..0000000 --- a/text_recognizer/network/convnext/residual.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Generic residual layer.""" -from typing import Callable - -from torch import Tensor, nn - - -class Residual(nn.Module): - """Residual layer.""" - - def __init__(self, fn: Callable) -> None: - super().__init__() - self.fn = fn - - def forward(self, x: Tensor) -> Tensor: - """Applies residual fn.""" - return self.fn(x) + x diff --git a/text_recognizer/network/convnext/attention.py b/text_recognizer/network/convnext/transformer.py index 6bc9692..6c53c48 100644 --- a/text_recognizer/network/convnext/attention.py +++ b/text_recognizer/network/convnext/transformer.py @@ -1,29 +1,21 @@ """Convolution self attention block.""" -import torch.nn.functional as F from einops import rearrange from torch import Tensor, einsum, nn from text_recognizer.network.convnext.norm import LayerNorm -from text_recognizer.network.convnext.residual import Residual - - -def l2norm(t: Tensor) -> Tensor: - return F.normalize(t, dim=-1) class FeedForward(nn.Module): def __init__(self, dim: int, mult: int = 4) -> None: super().__init__() inner_dim = int(dim * mult) - self.fn = Residual( - nn.Sequential( - LayerNorm(dim), - nn.Conv2d(dim, inner_dim, 1, bias=False), - nn.GELU(), - LayerNorm(inner_dim), - nn.Conv2d(inner_dim, dim, 1, bias=False), - ) + self.fn = nn.Sequential( + LayerNorm(dim), + nn.Conv2d(dim, inner_dim, 1, bias=False), + nn.GELU(), + LayerNorm(inner_dim), + nn.Conv2d(inner_dim, dim, 1, bias=False), ) def forward(self, x: Tensor) -> Tensor: @@ -46,8 +38,6 @@ class Attention(nn.Module): def forward(self, x: Tensor) -> Tensor: h, w = x.shape[-2:] - residual = x.clone() - x = self.norm(x) q, k, v = self.to_qkv(x).chunk(3, dim=1) @@ -56,24 +46,23 @@ class Attention(nn.Module): (q, k, v), ) - q, k = map(l2norm, (q, k)) - - sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + q = q * self.scale + sim = einsum("b h i d, b h j d -> b h i j", q, k) attn = sim.softmax(dim=-1) out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) - return self.to_out(out) + residual + return self.to_out(out) -class TransformerBlock(nn.Module): +class Transformer(nn.Module): def __init__(self, attn: Attention, ff: FeedForward) -> None: super().__init__() self.attn = attn self.ff = ff def forward(self, x: Tensor) -> Tensor: - x = self.attn(x) - x = self.ff(x) + x = x + self.attn(x) + x = x + self.ff(x) return x |