diff options
| -rw-r--r-- | text_recognizer/networks/convnext/convnext.py | 6 | 
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py index 68de81a..308c009 100644 --- a/text_recognizer/networks/convnext/convnext.py +++ b/text_recognizer/networks/convnext/convnext.py @@ -1,8 +1,8 @@  from typing import Optional, Sequence -from text_recognizer.networks.convnext.attention import TransformerBlock -from torch import nn, Tensor +from torch import Tensor, nn +from text_recognizer.networks.convnext.attention import TransformerBlock  from text_recognizer.networks.convnext.downsample import Downsample  from text_recognizer.networks.convnext.norm import LayerNorm @@ -70,5 +70,5 @@ class ConvNext(nn.Module):              for fn in blocks:                  x = fn(x)              x = down(x) -        x = self.attn(x) +        x = self.attn(x) if self.attn is not None else x          return self.norm(x)  |