summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/convnext
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-14 00:53:52 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-14 00:53:52 +0200
commite06a94f0df3c19ecd23024162b3697bd1e12c8a9 (patch)
tree41c463a89767a2cab1652effc7bfa112340e9177 /text_recognizer/networks/convnext
parent46c8fb98581b1a9de1ab07d1b0e9ac52f382b00c (diff)
Fix convnext
Diffstat (limited to 'text_recognizer/networks/convnext')
-rw-r--r--text_recognizer/networks/convnext/convnext.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index 68de81a..308c009 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -1,8 +1,8 @@
from typing import Optional, Sequence
-from text_recognizer.networks.convnext.attention import TransformerBlock
-from torch import nn, Tensor
+from torch import Tensor, nn
+from text_recognizer.networks.convnext.attention import TransformerBlock
from text_recognizer.networks.convnext.downsample import Downsample
from text_recognizer.networks.convnext.norm import LayerNorm
@@ -70,5 +70,5 @@ class ConvNext(nn.Module):
for fn in blocks:
x = fn(x)
x = down(x)
- x = self.attn(x)
+ x = self.attn(x) if self.attn is not None else x
return self.norm(x)