From 4a4d5f2a2ee06069140b0d861018a70c63ad3d46 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 13 Sep 2022 21:58:26 +0200 Subject: Add convnext attention --- text_recognizer/networks/convnext/convnext.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) (limited to 'text_recognizer/networks/convnext/convnext.py') diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py index a4556a0..68de81a 100644 --- a/text_recognizer/networks/convnext/convnext.py +++ b/text_recognizer/networks/convnext/convnext.py @@ -1,13 +1,9 @@ -from typing import Sequence +from typing import Optional, Sequence -from einops import reduce, rearrange -from einops.layers.torch import Rearrange -import torch -from torch import einsum, nn, Tensor -import torch.nn.functional as F +from text_recognizer.networks.convnext.attention import TransformerBlock +from torch import nn, Tensor from text_recognizer.networks.convnext.downsample import Downsample -from text_recognizer.networks.convnext.residual import Residual from text_recognizer.networks.convnext.norm import LayerNorm @@ -38,9 +34,11 @@ class ConvNext(nn.Module): dim_mults: Sequence[int] = (2, 4, 8), depths: Sequence[int] = (3, 3, 6), downsampling_factors: Sequence[Sequence[int]] = ((2, 2), (2, 2), (2, 2)), + attn: Optional[TransformerBlock] = None, ) -> None: super().__init__() dims = (dim, *map(lambda m: m * dim, dim_mults)) + self.attn = attn self.out_channels = dims[-1] self.stem = nn.Conv2d(1, dims[0], kernel_size=(7, 7), padding="same") self.layers = nn.ModuleList([]) @@ -65,11 +63,12 @@ class ConvNext(nn.Module): nn.init.trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.stem(x) for init_block, blocks, down in self.layers: x = init_block(x) for fn in blocks: x = fn(x) x = down(x) + x = self.attn(x) return self.norm(x) -- cgit v1.2.3-70-g09d2