From 53450493e0a13d835fd1d2457c49a9d60bee0e18 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 4 May 2021 23:11:44 +0200
Subject: Nyströmer implemented but not tested
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../networks/transformer/nystromer/nystromer.py   | 64 ++++++++++++++++++++++
 1 file changed, 64 insertions(+)
 create mode 100644 text_recognizer/networks/transformer/nystromer/nystromer.py

diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py
new file mode 100644
index 0000000..0283d69
--- /dev/null
+++ b/text_recognizer/networks/transformer/nystromer/nystromer.py
@@ -0,0 +1,64 @@
+"""Nyströmer encoder.
+
+Stolen from:
+    https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py
+
+"""
+from typing import Optional
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.transformer.mlp import FeedForward
+from text_recognizer.networks.transformer.norm import PreNorm
+from text_recognizer.networks.transformer.nystromer.attention import NystromAttention
+
+
+class Nystromer(nn.Module):
+    """Transformer encoder built from Nyström attention and feedforward blocks."""
+
+    def __init__(
+        self,
+        *,
+        dim: int,
+        depth: int,
+        dim_head: int = 64,
+        num_heads: int = 8,
+        num_landmarks: int = 256,
+        inverse_iter: int = 6,
+        residual: bool = True,
+        residual_conv_kernel: int = 33,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+        # Each layer pairs a pre-norm Nyström attention block with a pre-norm
+        # feedforward block; nested ModuleLists register the submodules' parameters.
+        self.layers = nn.ModuleList(
+            [
+                nn.ModuleList(
+                    [
+                        PreNorm(
+                            dim,
+                            NystromAttention(
+                                dim=dim,
+                                dim_head=dim_head,
+                                num_heads=num_heads,
+                                num_landmarks=num_landmarks,
+                                inverse_iter=inverse_iter,
+                                residual=residual,
+                                residual_conv_kernel=residual_conv_kernel,
+                                dropout_rate=dropout_rate,
+                            ),
+                        ),
+                        PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)),
+                    ]
+                )
+                for _ in range(depth)
+            ]
+        )
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        """Apply each attention and feedforward block with residual connections."""
+        for attn, ff in self.layers:
+            x = attn(x, mask=mask) + x
+            x = ff(x) + x
+        return x
--
cgit v1.2.3-70-g09d2
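
Since the commit message notes the module is untested, a minimal smoke-test sketch like the one below could exercise it. It is not part of the commit; it assumes the sibling NystromAttention, FeedForward, and PreNorm modules accept the arguments shown in the diff, that PreNorm forwards the mask keyword to the wrapped attention, and that the encoder preserves the input shape. The hyperparameters and shapes are illustrative only, not taken from the repository's configs.

    # Hypothetical smoke test for the new encoder (not part of the commit).
    import torch

    from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer

    # Small, assumed hyperparameters to keep the check fast.
    encoder = Nystromer(
        dim=256,
        depth=2,
        dim_head=64,
        num_heads=4,
        num_landmarks=64,
    )

    x = torch.randn(1, 128, 256)  # (batch, sequence length, dim)
    out = encoder(x)  # mask is optional
    assert out.shape == x.shape  # residual pre-norm blocks keep the shape

Each layer follows the usual pre-norm residual pattern: the input is normalized, passed through Nyström attention (or the feedforward block), and added back to the residual stream, so stacking depth layers leaves the tensor shape unchanged.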