summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/convnext
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/convnext')
-rw-r--r--text_recognizer/networks/convnext/convnext.py7
-rw-r--r--text_recognizer/networks/convnext/downsample.py4
-rw-r--r--text_recognizer/networks/convnext/norm.py3
-rw-r--r--text_recognizer/networks/convnext/residual.py3
4 files changed, 15 insertions, 2 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index b4dfad7..9419a15 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -1,3 +1,4 @@
+"""ConvNext module."""
from typing import Optional, Sequence
from torch import Tensor, nn
@@ -8,7 +9,9 @@ from text_recognizer.networks.convnext.norm import LayerNorm
class ConvNextBlock(nn.Module):
- def __init__(self, dim, dim_out, mult):
+ """ConvNext block."""
+
+ def __init__(self, dim: int, dim_out: int, mult: int) -> None:
super().__init__()
self.ds_conv = nn.Conv2d(
dim, dim, kernel_size=(7, 7), padding="same", groups=dim
@@ -21,7 +24,7 @@ class ConvNextBlock(nn.Module):
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
- def forward(self, x):
+ def forward(self, x: Tensor) -> Tensor:
h = self.ds_conv(x)
h = self.net(h)
return h + self.res_conv(x)
diff --git a/text_recognizer/networks/convnext/downsample.py b/text_recognizer/networks/convnext/downsample.py
index c28ecca..a8a0466 100644
--- a/text_recognizer/networks/convnext/downsample.py
+++ b/text_recognizer/networks/convnext/downsample.py
@@ -1,3 +1,4 @@
+"""Convnext downsample module."""
from typing import Tuple
from einops.layers.torch import Rearrange
@@ -5,6 +6,8 @@ from torch import Tensor, nn
class Downsample(nn.Module):
+ """Downsamples feature maps by patches."""
+
def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
super().__init__()
s1, s2 = factors
@@ -14,4 +17,5 @@ class Downsample(nn.Module):
)
def forward(self, x: Tensor) -> Tensor:
+ """Applies patch function."""
return self.fn(x)
diff --git a/text_recognizer/networks/convnext/norm.py b/text_recognizer/networks/convnext/norm.py
index 23cf07a..3355de9 100644
--- a/text_recognizer/networks/convnext/norm.py
+++ b/text_recognizer/networks/convnext/norm.py
@@ -4,11 +4,14 @@ from torch import Tensor, nn
class LayerNorm(nn.Module):
+ """Layer norm for convolutions."""
+
def __init__(self, dim: int) -> None:
super().__init__()
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x: Tensor) -> Tensor:
+ """Applies layer norm."""
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=1, keepdim=True)
diff --git a/text_recognizer/networks/convnext/residual.py b/text_recognizer/networks/convnext/residual.py
index 8e76ae9..dfc2847 100644
--- a/text_recognizer/networks/convnext/residual.py
+++ b/text_recognizer/networks/convnext/residual.py
@@ -5,9 +5,12 @@ from torch import Tensor, nn
class Residual(nn.Module):
+ """Residual layer."""
+
def __init__(self, fn: Callable) -> None:
super().__init__()
self.fn = fn
def forward(self, x: Tensor) -> Tensor:
+ """Applies residual fn."""
return self.fn(x) + x