summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:11:04 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:11:04 +0200
commit72ce2361f97676fc50ebc6b68b9083a402fa30c5 (patch)
treea22f1122c9d8768aba8e05f38366ab40a6701f7e
parentd784c49566e5a7705539fb4ceb0461bad8f41361 (diff)
Update convnext
-rw-r--r--text_recognizer/network/convnext/__init__.py7
-rw-r--r--text_recognizer/network/convnext/convnext.py41
-rw-r--r--text_recognizer/network/convnext/downsample.py9
-rw-r--r--text_recognizer/network/convnext/residual.py16
-rw-r--r--text_recognizer/network/convnext/transformer.py (renamed from text_recognizer/network/convnext/attention.py)35
5 files changed, 42 insertions, 66 deletions
diff --git a/text_recognizer/network/convnext/__init__.py b/text_recognizer/network/convnext/__init__.py
index dcff3fc..e69de29 100644
--- a/text_recognizer/network/convnext/__init__.py
+++ b/text_recognizer/network/convnext/__init__.py
@@ -1,7 +0,0 @@
-"""Convnext module."""
-from text_recognizer.network.convnext.attention import (
- Attention,
- FeedForward,
- TransformerBlock,
-)
-from text_recognizer.network.convnext.convnext import ConvNext
diff --git a/text_recognizer/network/convnext/convnext.py b/text_recognizer/network/convnext/convnext.py
index 6acf059..8eea9df 100644
--- a/text_recognizer/network/convnext/convnext.py
+++ b/text_recognizer/network/convnext/convnext.py
@@ -1,11 +1,27 @@
"""ConvNext module."""
from typing import Optional, Sequence
+import torch
from torch import Tensor, nn
-from text_recognizer.network.convnext.attention import TransformerBlock
-from text_recognizer.network.convnext.downsample import Downsample
-from text_recognizer.network.convnext.norm import LayerNorm
+from .transformer import Transformer
+from .downsample import Downsample
+from .norm import LayerNorm
+
+
+class GRN(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.eps = eps
+ self.gamma = nn.Parameter(torch.zeros(dim, 1, 1))
+ self.bias = nn.Parameter(torch.zeros(dim, 1, 1))
+
+ def forward(self, x):
+ spatial_l2_norm = x.norm(p=2, dim=(2, 3), keepdim=True)
+ feat_norm = spatial_l2_norm / spatial_l2_norm.mean(dim=-1, keepdim=True).clamp(
+ min=self.eps
+ )
+ return x * feat_norm * self.gamma + self.bias + x
class ConvNextBlock(nn.Module):
@@ -13,14 +29,14 @@ class ConvNextBlock(nn.Module):
def __init__(self, dim: int, dim_out: int, mult: int) -> None:
super().__init__()
- self.ds_conv = nn.Conv2d(
- dim, dim, kernel_size=(7, 7), padding="same", groups=dim
- )
+ inner_dim = mult * dim_out
+ self.ds_conv = nn.Conv2d(dim, dim, kernel_size=7, padding="same", groups=dim)
self.net = nn.Sequential(
LayerNorm(dim),
- nn.Conv2d(dim, dim_out * mult, kernel_size=(3, 3), padding="same"),
+ nn.Conv2d(dim, inner_dim, kernel_size=3, stride=1, padding="same"),
nn.GELU(),
- nn.Conv2d(dim_out * mult, dim_out, kernel_size=(3, 3), padding="same"),
+ GRN(inner_dim),
+ nn.Conv2d(inner_dim, dim_out, kernel_size=3, stride=1, padding="same"),
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
@@ -36,8 +52,7 @@ class ConvNext(nn.Module):
dim: int = 16,
dim_mults: Sequence[int] = (2, 4, 8),
depths: Sequence[int] = (3, 3, 6),
- downsampling_factors: Sequence[Sequence[int]] = ((2, 2), (2, 2), (2, 2)),
- attn: Optional[TransformerBlock] = None,
+ attn: Optional[Transformer] = None,
) -> None:
super().__init__()
dims = (dim, *map(lambda m: m * dim, dim_mults))
@@ -51,11 +66,10 @@ class ConvNext(nn.Module):
self.layers.append(
nn.ModuleList(
[
- ConvNextBlock(dim_in, dim_in, 2),
nn.ModuleList(
[ConvNextBlock(dim_in, dim_in, 2) for _ in range(depths[i])]
),
- Downsample(dim_in, dim_out, downsampling_factors[i]),
+ Downsample(dim_in, dim_out),
]
)
)
@@ -68,8 +82,7 @@ class ConvNext(nn.Module):
def forward(self, x: Tensor) -> Tensor:
x = self.stem(x)
- for init_block, blocks, down in self.layers:
- x = init_block(x)
+ for blocks, down in self.layers:
for fn in blocks:
x = fn(x)
x = down(x)
diff --git a/text_recognizer/network/convnext/downsample.py b/text_recognizer/network/convnext/downsample.py
index a8a0466..dcc14aa 100644
--- a/text_recognizer/network/convnext/downsample.py
+++ b/text_recognizer/network/convnext/downsample.py
@@ -1,6 +1,4 @@
"""Convnext downsample module."""
-from typing import Tuple
-
from einops.layers.torch import Rearrange
from torch import Tensor, nn
@@ -8,12 +6,11 @@ from torch import Tensor, nn
class Downsample(nn.Module):
"""Downsamples feature maps by patches."""
- def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
+ def __init__(self, dim: int, dim_out: int) -> None:
super().__init__()
- s1, s2 = factors
self.fn = nn.Sequential(
- Rearrange("b c (h s1) (w s2) -> b (c s1 s2) h w", s1=s1, s2=s2),
- nn.Conv2d(dim * s1 * s2, dim_out, 1),
+ Rearrange("b c (h s1) (w s2) -> b (c s1 s2) h w", s1=2, s2=2),
+ nn.Conv2d(dim * 4, dim_out, 1),
)
def forward(self, x: Tensor) -> Tensor:
diff --git a/text_recognizer/network/convnext/residual.py b/text_recognizer/network/convnext/residual.py
deleted file mode 100644
index dfc2847..0000000
--- a/text_recognizer/network/convnext/residual.py
+++ /dev/null
@@ -1,16 +0,0 @@
-"""Generic residual layer."""
-from typing import Callable
-
-from torch import Tensor, nn
-
-
-class Residual(nn.Module):
- """Residual layer."""
-
- def __init__(self, fn: Callable) -> None:
- super().__init__()
- self.fn = fn
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies residual fn."""
- return self.fn(x) + x
diff --git a/text_recognizer/network/convnext/attention.py b/text_recognizer/network/convnext/transformer.py
index 6bc9692..6c53c48 100644
--- a/text_recognizer/network/convnext/attention.py
+++ b/text_recognizer/network/convnext/transformer.py
@@ -1,29 +1,21 @@
"""Convolution self attention block."""
-import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, einsum, nn
from text_recognizer.network.convnext.norm import LayerNorm
-from text_recognizer.network.convnext.residual import Residual
-
-
-def l2norm(t: Tensor) -> Tensor:
- return F.normalize(t, dim=-1)
class FeedForward(nn.Module):
def __init__(self, dim: int, mult: int = 4) -> None:
super().__init__()
inner_dim = int(dim * mult)
- self.fn = Residual(
- nn.Sequential(
- LayerNorm(dim),
- nn.Conv2d(dim, inner_dim, 1, bias=False),
- nn.GELU(),
- LayerNorm(inner_dim),
- nn.Conv2d(inner_dim, dim, 1, bias=False),
- )
+ self.fn = nn.Sequential(
+ LayerNorm(dim),
+ nn.Conv2d(dim, inner_dim, 1, bias=False),
+ nn.GELU(),
+ LayerNorm(inner_dim),
+ nn.Conv2d(inner_dim, dim, 1, bias=False),
)
def forward(self, x: Tensor) -> Tensor:
@@ -46,8 +38,6 @@ class Attention(nn.Module):
def forward(self, x: Tensor) -> Tensor:
h, w = x.shape[-2:]
- residual = x.clone()
-
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=1)
@@ -56,24 +46,23 @@ class Attention(nn.Module):
(q, k, v),
)
- q, k = map(l2norm, (q, k))
-
- sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+ q = q * self.scale
+ sim = einsum("b h i d, b h j d -> b h i j", q, k)
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
- return self.to_out(out) + residual
+ return self.to_out(out)
-class TransformerBlock(nn.Module):
+class Transformer(nn.Module):
def __init__(self, attn: Attention, ff: FeedForward) -> None:
super().__init__()
self.attn = attn
self.ff = ff
def forward(self, x: Tensor) -> Tensor:
- x = self.attn(x)
- x = self.ff(x)
+ x = x + self.attn(x)
+ x = x + self.ff(x)
return x