From 4a4d5f2a2ee06069140b0d861018a70c63ad3d46 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 13 Sep 2022 21:58:26 +0200
Subject: Add convnext attention

---
Review notes (not part of the commit message):
- ConvNext.forward now skips the attention block when attn is None, which is
  the constructor default; calling None raised a TypeError.
- Docstrings and comments were added to the new attention module.
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; verify context-line indentation before applying.
- The commit id and blob ids below are those of the original commit.

 text_recognizer/networks/convnext/__init__.py  |  5 ++
 text_recognizer/networks/convnext/attention.py | 94 ++++++++++++++++++++++++++
 text_recognizer/networks/convnext/convnext.py  | 17 +++--
 3 files changed, 108 insertions(+), 8 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/networks/convnext/__init__.py b/text_recognizer/networks/convnext/__init__.py
index 8d8470f..1743f09 100644
--- a/text_recognizer/networks/convnext/__init__.py
+++ b/text_recognizer/networks/convnext/__init__.py
@@ -1,2 +1,7 @@
 """Convnext module."""
 from text_recognizer.networks.convnext.convnext import ConvNext
+from text_recognizer.networks.convnext.attention import (
+    Attention,
+    FeedForward,
+    TransformerBlock,
+)
diff --git a/text_recognizer/networks/convnext/attention.py b/text_recognizer/networks/convnext/attention.py
index e69de29..7f03436 100644
--- a/text_recognizer/networks/convnext/attention.py
+++ b/text_recognizer/networks/convnext/attention.py
@@ -0,0 +1,94 @@
+"""Convolution self attention block."""
+
+from einops import reduce, rearrange
+from torch import einsum, nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.convnext.norm import LayerNorm
+from text_recognizer.networks.convnext.residual import Residual
+
+
+def l2norm(t: Tensor) -> Tensor:
+    """Normalize t to unit L2 norm along its last dimension."""
+    return F.normalize(t, dim=-1)
+
+
+class FeedForward(nn.Module):
+    """Pointwise feed-forward block (1x1 convolutions) wrapped in Residual."""
+
+    def __init__(self, dim: int, mult: int = 4) -> None:
+        super().__init__()
+        # Hidden width is the channel count expanded by mult.
+        inner_dim = int(dim * mult)
+        self.fn = Residual(
+            nn.Sequential(
+                LayerNorm(dim),
+                nn.Conv2d(dim, inner_dim, 1, bias=False),
+                nn.GELU(),
+                LayerNorm(inner_dim),
+                nn.Conv2d(inner_dim, dim, 1, bias=False),
+            )
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.fn(x)
+
+
+class Attention(nn.Module):
+    """Multi-head self attention over the spatial positions of a feature map.
+
+    Queries and keys are L2-normalized, so the logits are cosine similarities
+    multiplied by the fixed factor scale.
+    """
+
+    def __init__(
+        self, dim: int, heads: int = 4, dim_head: int = 64, scale: int = 8
+    ) -> None:
+        super().__init__()
+        self.scale = scale
+        self.heads = heads
+        inner_dim = heads * dim_head
+        self.norm = LayerNorm(dim)
+
+        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
+        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Attend over all positions of x and add the input back as a skip."""
+        h, w = x.shape[-2:]
+
+        residual = x.clone()
+
+        x = self.norm(x)
+
+        q, k, v = self.to_qkv(x).chunk(3, dim=1)
+        # Split channels into heads and flatten the spatial dims into one axis.
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) ... -> b h (...) c", h=self.heads),
+            (q, k, v),
+        )
+
+        q, k = map(l2norm, (q, k))
+
+        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+
+        # Restore the (height, width) layout and merge the heads into channels.
+        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+        return self.to_out(out) + residual
+
+
+class TransformerBlock(nn.Module):
+    """Run an Attention module followed by a FeedForward module."""
+
+    def __init__(self, attn: Attention, ff: FeedForward) -> None:
+        super().__init__()
+        self.attn = attn
+        self.ff = ff
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.attn(x)
+        x = self.ff(x)
+        return x
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index a4556a0..68de81a 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -1,13 +1,9 @@
-from typing import Sequence
+from typing import Optional, Sequence
 
-from einops import reduce, rearrange
-from einops.layers.torch import Rearrange
-import torch
-from torch import einsum, nn, Tensor
-import torch.nn.functional as F
+from text_recognizer.networks.convnext.attention import TransformerBlock
+from torch import nn, Tensor
 
 from text_recognizer.networks.convnext.downsample import Downsample
-from text_recognizer.networks.convnext.residual import Residual
 from text_recognizer.networks.convnext.norm import LayerNorm
 
 
@@ -38,9 +34,11 @@ class ConvNext(nn.Module):
         dim_mults: Sequence[int] = (2, 4, 8),
         depths: Sequence[int] = (3, 3, 6),
         downsampling_factors: Sequence[Sequence[int]] = ((2, 2), (2, 2), (2, 2)),
+        attn: Optional[TransformerBlock] = None,
     ) -> None:
         super().__init__()
         dims = (dim, *map(lambda m: m * dim, dim_mults))
+        self.attn = attn
         self.out_channels = dims[-1]
         self.stem = nn.Conv2d(1, dims[0], kernel_size=(7, 7), padding="same")
         self.layers = nn.ModuleList([])
@@ -65,11 +63,14 @@ class ConvNext(nn.Module):
             nn.init.trunc_normal_(m.weight, std=0.02)
             nn.init.constant_(m.bias, 0)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.stem(x)
         for init_block, blocks, down in self.layers:
             x = init_block(x)
             for fn in blocks:
                 x = fn(x)
             x = down(x)
+        # attn is optional and defaults to None; skip it when absent.
+        if self.attn is not None:
+            x = self.attn(x)
         return self.norm(x)
-- 
cgit v1.2.3-70-g09d2