summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/convnext/convnext.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index 308c009..b4dfad7 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -38,9 +38,9 @@ class ConvNext(nn.Module):
) -> None:
super().__init__()
dims = (dim, *map(lambda m: m * dim, dim_mults))
- self.attn = attn
+ self.attn = attn if attn is not None else nn.Identity()
self.out_channels = dims[-1]
- self.stem = nn.Conv2d(1, dims[0], kernel_size=(7, 7), padding="same")
+ self.stem = nn.Conv2d(1, dims[0], kernel_size=7, padding="same")
self.layers = nn.ModuleList([])
for i in range(len(dims) - 1):
@@ -70,5 +70,5 @@ class ConvNext(nn.Module):
for fn in blocks:
x = fn(x)
x = down(x)
- x = self.attn(x) if self.attn is not None else x
+ x = self.attn(x)
return self.norm(x)