From 862182eaf4e0cc88e26e53609c67d9b98451f14c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 27 Sep 2022 00:10:47 +0200 Subject: Update convnext --- text_recognizer/networks/convnext/convnext.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'text_recognizer/networks/convnext') diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py index 308c009..b4dfad7 100644 --- a/text_recognizer/networks/convnext/convnext.py +++ b/text_recognizer/networks/convnext/convnext.py @@ -38,9 +38,9 @@ class ConvNext(nn.Module): ) -> None: super().__init__() dims = (dim, *map(lambda m: m * dim, dim_mults)) - self.attn = attn + self.attn = attn if attn is not None else nn.Identity() self.out_channels = dims[-1] - self.stem = nn.Conv2d(1, dims[0], kernel_size=(7, 7), padding="same") + self.stem = nn.Conv2d(1, dims[0], kernel_size=7, padding="same") self.layers = nn.ModuleList([]) for i in range(len(dims) - 1): @@ -70,5 +70,5 @@ class ConvNext(nn.Module): for fn in blocks: x = fn(x) x = down(x) - x = self.attn(x) if self.attn is not None else x + x = self.attn(x) return self.norm(x) -- cgit v1.2.3-70-g09d2