diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/convnext/convnext.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py index 308c009..b4dfad7 100644 --- a/text_recognizer/networks/convnext/convnext.py +++ b/text_recognizer/networks/convnext/convnext.py @@ -38,9 +38,9 @@ class ConvNext(nn.Module): ) -> None: super().__init__() dims = (dim, *map(lambda m: m * dim, dim_mults)) - self.attn = attn + self.attn = attn if attn is not None else nn.Identity() self.out_channels = dims[-1] - self.stem = nn.Conv2d(1, dims[0], kernel_size=(7, 7), padding="same") + self.stem = nn.Conv2d(1, dims[0], kernel_size=7, padding="same") self.layers = nn.ModuleList([]) for i in range(len(dims) - 1): @@ -70,5 +70,5 @@ class ConvNext(nn.Module): for fn in blocks: x = fn(x) x = down(x) - x = self.attn(x) if self.attn is not None else x + x = self.attn(x) return self.norm(x) |