From 63a69eb4015f203ff53f9fd0cbed10abbf3bbd87 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 23 Jul 2021 15:48:22 +0200 Subject: Remove nystromer --- .../networks/transformer/nystromer/nystromer.py | 64 ---------------------- 1 file changed, 64 deletions(-) delete mode 100644 text_recognizer/networks/transformer/nystromer/nystromer.py (limited to 'text_recognizer/networks/transformer/nystromer/nystromer.py') diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py deleted file mode 100644 index 2113f1f..0000000 --- a/text_recognizer/networks/transformer/nystromer/nystromer.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Nyströmer encoder. - -Stolen from: - https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py - -""" -from typing import Optional - -from torch import nn, Tensor - -from text_recognizer.networks.transformer.mlp import FeedForward -from text_recognizer.networks.transformer.norm import PreNorm -from text_recognizer.networks.transformer.nystromer.attention import NystromAttention - - -class Nystromer(nn.Module): - def __init__( - self, - *, - dim: int, - depth: int, - dim_head: int = 64, - num_heads: int = 8, - num_landmarks: int = 256, - inverse_iter: int = 6, - residual: bool = True, - residual_conv_kernel: int = 33, - dropout_rate: float = 0.0, - glu: bool = True, - ) -> None: - super().__init__() - self.dim = dim - self.layers = nn.ModuleList( - [ - nn.ModuleList( - [ - PreNorm( - dim, - NystromAttention( - dim=dim, - dim_head=dim_head, - num_heads=num_heads, - num_landmarks=num_landmarks, - inverse_iter=inverse_iter, - residual=residual, - residual_conv_kernel=residual_conv_kernel, - dropout_rate=dropout_rate, - ), - ), - PreNorm( - dim, - FeedForward(dim=dim, glu=glu, dropout_rate=dropout_rate), - ), - ] - ) - for _ in range(depth) - ] - ) - - def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: - for attn, ff in self.layers: - x = attn(x, mask=mask) + x - x = ff(x) + x - return x -- cgit v1.2.3-70-g09d2