Diffstat (limited to 'text_recognizer/network/convnext/convnext.py')
 text_recognizer/network/convnext/convnext.py | 41 +++++++++++++++++++++++++++--------------
 1 file changed, 27 insertions(+), 14 deletions(-)
diff --git a/text_recognizer/network/convnext/convnext.py b/text_recognizer/network/convnext/convnext.py
index 6acf059..8eea9df 100644
--- a/text_recognizer/network/convnext/convnext.py
+++ b/text_recognizer/network/convnext/convnext.py
@@ -1,11 +1,27 @@
 """ConvNext module."""
 from typing import Optional, Sequence

+import torch
 from torch import Tensor, nn

-from text_recognizer.network.convnext.attention import TransformerBlock
-from text_recognizer.network.convnext.downsample import Downsample
-from text_recognizer.network.convnext.norm import LayerNorm
+from .transformer import Transformer
+from .downsample import Downsample
+from .norm import LayerNorm
+
+
+class GRN(nn.Module):
+    """Global response normalization, as introduced in ConvNeXt V2."""
+
+    def __init__(self, dim: int, eps: float = 1e-5) -> None:
+        super().__init__()
+        self.eps = eps
+        self.gamma = nn.Parameter(torch.zeros(dim, 1, 1))
+        self.bias = nn.Parameter(torch.zeros(dim, 1, 1))
+
+    def forward(self, x: Tensor) -> Tensor:
+        # Per-channel L2 norm over the spatial dims: shape (B, C, 1, 1).
+        spatial_l2_norm = x.norm(p=2, dim=(2, 3), keepdim=True)
+        # Normalize by the mean norm across channels (dim=1 in NCHW;
+        # dim=-1 would average over a size-1 axis and do nothing).
+        feat_norm = spatial_l2_norm / spatial_l2_norm.mean(dim=1, keepdim=True).clamp(
+            min=self.eps
+        )
+        return x * feat_norm * self.gamma + self.bias + x


 class ConvNextBlock(nn.Module):
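The new GRN layer implements global response normalization from ConvNeXt V2: a per-channel spatial L2 norm is divided by the mean of those norms across channels, and a learned scale and bias are applied around a residual path. A minimal sanity check of the block above (the class name GRN and the NCHW shapes come from this diff; the rest is illustrative):

```python
import torch

grn = GRN(dim=8)
x = torch.randn(2, 8, 16, 16)  # NCHW, matching the (dim, 1, 1) parameters
out = grn(x)
assert out.shape == x.shape
# gamma and bias are zero-initialized, so GRN starts as the identity mapping.
assert torch.allclose(out, x)
```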
@@ -13,14 +29,14 @@ class ConvNextBlock(nn.Module):
     def __init__(self, dim: int, dim_out: int, mult: int) -> None:
         super().__init__()
-        self.ds_conv = nn.Conv2d(
-            dim, dim, kernel_size=(7, 7), padding="same", groups=dim
-        )
+        inner_dim = mult * dim_out
+        self.ds_conv = nn.Conv2d(dim, dim, kernel_size=7, padding="same", groups=dim)
         self.net = nn.Sequential(
             LayerNorm(dim),
-            nn.Conv2d(dim, dim_out * mult, kernel_size=(3, 3), padding="same"),
+            nn.Conv2d(dim, inner_dim, kernel_size=3, stride=1, padding="same"),
             nn.GELU(),
-            nn.Conv2d(dim_out * mult, dim_out, kernel_size=(3, 3), padding="same"),
+            GRN(inner_dim),
+            nn.Conv2d(inner_dim, dim_out, kernel_size=3, stride=1, padding="same"),
         )
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
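With the inverted-bottleneck rewrite, inner_dim = mult * dim_out and GRN sits between the GELU and the projection; all convolutions use "same" padding, so only the channel count changes. A hedged shape check, assuming the block's forward (not shown in this hunk) applies ds_conv, net, and adds res_conv(x):

```python
import torch

block = ConvNextBlock(dim=16, dim_out=32, mult=2)  # inner_dim = 64
y = block(torch.randn(1, 16, 56, 56))
assert y.shape == (1, 32, 56, 56)  # channels change, spatial dims do not
```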
@@ -36,8 +52,7 @@ class ConvNext(nn.Module):
         dim: int = 16,
         dim_mults: Sequence[int] = (2, 4, 8),
         depths: Sequence[int] = (3, 3, 6),
-        downsampling_factors: Sequence[Sequence[int]] = ((2, 2), (2, 2), (2, 2)),
-        attn: Optional[TransformerBlock] = None,
+        attn: Optional[Transformer] = None,
     ) -> None:
         super().__init__()
         dims = (dim, *map(lambda m: m * dim, dim_mults))
@@ -51,11 +66,10 @@ class ConvNext(nn.Module):
             self.layers.append(
                 nn.ModuleList(
                     [
-                        ConvNextBlock(dim_in, dim_in, 2),
                         nn.ModuleList(
                             [ConvNextBlock(dim_in, dim_in, 2) for _ in range(depths[i])]
                         ),
-                        Downsample(dim_in, dim_out, downsampling_factors[i]),
+                        Downsample(dim_in, dim_out),
                     ]
                 )
             )
@@ -68,8 +82,7 @@ class ConvNext(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         x = self.stem(x)
-        for init_block, blocks, down in self.layers:
-            x = init_block(x)
+        for blocks, down in self.layers:
             for fn in blocks:
                 x = fn(x)
             x = down(x)
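End to end, the forward pass is now a plain loop over (blocks, down) pairs, with the separate init_block removed. A usage sketch under stated assumptions: a single-channel stem (plausible for a text recognizer) and a fixed stride-2 Downsample, neither of which is shown in this diff:

```python
import torch

net = ConvNext(dim=16, dim_mults=(2, 4, 8), depths=(3, 3, 6))
feats = net(torch.randn(1, 1, 256, 256))
# dims = (16, 32, 64, 128); with three stride-2 downsamples this would give
# feats.shape == (1, 128, 32, 32), but the stem and Downsample details are assumed.
print(feats.shape)
```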