summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/convnext/convnext.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/convnext/convnext.py')
-rw-r--r--text_recognizer/networks/convnext/convnext.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index b4dfad7..9419a15 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -1,3 +1,4 @@
+"""ConvNext module."""
from typing import Optional, Sequence
from torch import Tensor, nn
@@ -8,7 +9,9 @@ from text_recognizer.networks.convnext.norm import LayerNorm
class ConvNextBlock(nn.Module):
- def __init__(self, dim, dim_out, mult):
+ """ConvNext block."""
+
+ def __init__(self, dim: int, dim_out: int, mult: int) -> None:
super().__init__()
self.ds_conv = nn.Conv2d(
dim, dim, kernel_size=(7, 7), padding="same", groups=dim
@@ -21,7 +24,7 @@ class ConvNextBlock(nn.Module):
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
- def forward(self, x):
+ def forward(self, x: Tensor) -> Tensor:
h = self.ds_conv(x)
h = self.net(h)
return h + self.res_conv(x)