diff options
Diffstat (limited to 'text_recognizer/networks/convnext')
-rw-r--r-- | text_recognizer/networks/convnext/convnext.py | 7 | ||||
-rw-r--r-- | text_recognizer/networks/convnext/downsample.py | 4 | ||||
-rw-r--r-- | text_recognizer/networks/convnext/norm.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/convnext/residual.py | 3 |
4 files changed, 15 insertions, 2 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py index b4dfad7..9419a15 100644 --- a/text_recognizer/networks/convnext/convnext.py +++ b/text_recognizer/networks/convnext/convnext.py @@ -1,3 +1,4 @@ +"""ConvNext module.""" from typing import Optional, Sequence from torch import Tensor, nn @@ -8,7 +9,9 @@ from text_recognizer.networks.convnext.norm import LayerNorm class ConvNextBlock(nn.Module): - def __init__(self, dim, dim_out, mult): + """ConvNext block.""" + + def __init__(self, dim: int, dim_out: int, mult: int) -> None: super().__init__() self.ds_conv = nn.Conv2d( dim, dim, kernel_size=(7, 7), padding="same", groups=dim @@ -21,7 +24,7 @@ class ConvNextBlock(nn.Module): ) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: h = self.ds_conv(x) h = self.net(h) return h + self.res_conv(x) diff --git a/text_recognizer/networks/convnext/downsample.py b/text_recognizer/networks/convnext/downsample.py index c28ecca..a8a0466 100644 --- a/text_recognizer/networks/convnext/downsample.py +++ b/text_recognizer/networks/convnext/downsample.py @@ -1,3 +1,4 @@ +"""Convnext downsample module.""" from typing import Tuple from einops.layers.torch import Rearrange @@ -5,6 +6,8 @@ from torch import Tensor, nn class Downsample(nn.Module): + """Downsamples feature maps by patches.""" + def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None: super().__init__() s1, s2 = factors @@ -14,4 +17,5 @@ class Downsample(nn.Module): ) def forward(self, x: Tensor) -> Tensor: + """Applies patch function.""" return self.fn(x) diff --git a/text_recognizer/networks/convnext/norm.py b/text_recognizer/networks/convnext/norm.py index 23cf07a..3355de9 100644 --- a/text_recognizer/networks/convnext/norm.py +++ b/text_recognizer/networks/convnext/norm.py @@ -4,11 +4,14 @@ from torch import Tensor, nn class LayerNorm(nn.Module): + """Layer norm for convolutions.""" + def __init__(self, dim: int) -> None: super().__init__() self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) def forward(self, x: Tensor) -> Tensor: + """Applies layer norm.""" eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) diff --git a/text_recognizer/networks/convnext/residual.py b/text_recognizer/networks/convnext/residual.py index 8e76ae9..dfc2847 100644 --- a/text_recognizer/networks/convnext/residual.py +++ b/text_recognizer/networks/convnext/residual.py @@ -5,9 +5,12 @@ from torch import Tensor, nn class Residual(nn.Module): + """Residual layer.""" + def __init__(self, fn: Callable) -> None: super().__init__() self.fn = fn def forward(self, x: Tensor) -> Tensor: + """Applies residual fn.""" return self.fn(x) + x |