diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:47 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:47 +0200 |
commit | 862182eaf4e0cc88e26e53609c67d9b98451f14c (patch) | |
tree | 478ebaaa9f8c249209b568eba7948feffbb29180 /text_recognizer/networks/convnext | |
parent | 0540237d794ab2071764dc74e4d3bb52f5bf44be (diff) |
Update convnext
Diffstat (limited to 'text_recognizer/networks/convnext')
-rw-r--r-- | text_recognizer/networks/convnext/convnext.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py index 308c009..b4dfad7 100644 --- a/text_recognizer/networks/convnext/convnext.py +++ b/text_recognizer/networks/convnext/convnext.py @@ -38,9 +38,9 @@ class ConvNext(nn.Module): ) -> None: super().__init__() dims = (dim, *map(lambda m: m * dim, dim_mults)) - self.attn = attn + self.attn = attn if attn is not None else nn.Identity() self.out_channels = dims[-1] - self.stem = nn.Conv2d(1, dims[0], kernel_size=(7, 7), padding="same") + self.stem = nn.Conv2d(1, dims[0], kernel_size=7, padding="same") self.layers = nn.ModuleList([]) for i in range(len(dims) - 1): @@ -70,5 +70,5 @@ class ConvNext(nn.Module): for fn in blocks: x = fn(x) x = down(x) - x = self.attn(x) if self.attn is not None else x + x = self.attn(x) return self.norm(x) |