"""Nyströmer encoder.

Adapted from:
https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py
"""
from typing import Optional

from torch import nn, Tensor

from text_recognizer.networks.transformer.mlp import FeedForward
from text_recognizer.networks.transformer.norm import PreNorm
from text_recognizer.networks.transformer.nystromer.attention import NystromAttention


class Nystromer(nn.Module):
    """Stack of ``depth`` layers, each a Nyström-attention sub-block followed
    by a feed-forward sub-block, both wrapped in ``PreNorm`` with residual
    connections.

    Args:
        dim: Feature dimension shared by every sub-module (passed to
            ``PreNorm``, ``NystromAttention`` and ``FeedForward``).
        depth: Number of attention + feed-forward layers.
        dim_head: Forwarded to ``NystromAttention`` as ``dim_head``.
        num_heads: Forwarded to ``NystromAttention`` as ``num_heads``.
        num_landmarks: Forwarded to ``NystromAttention`` as ``num_landmarks``.
        inverse_iter: Forwarded to ``NystromAttention`` as ``inverse_iter``
            (presumably the iteration count of its pseudo-inverse
            approximation -- confirm in ``NystromAttention``).
        residual: Forwarded to ``NystromAttention`` as ``residual``.
        residual_conv_kernel: Forwarded to ``NystromAttention`` as
            ``residual_conv_kernel``.
        dropout_rate: Dropout rate forwarded to both ``NystromAttention`` and
            ``FeedForward``.
    """

    def __init__(
        self,
        *,
        dim: int,
        depth: int,
        dim_head: int = 64,
        num_heads: int = 8,
        num_landmarks: int = 256,
        inverse_iter: int = 6,
        residual: bool = True,
        residual_conv_kernel: int = 33,
        dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        # Each layer's (attention, feed-forward) pair must itself be an
        # nn.Module -- here an inner nn.ModuleList -- because nn.ModuleList
        # rejects plain Python lists with a TypeError, and the inner container
        # is also what registers the sub-modules' parameters.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PreNorm(
                            dim,
                            NystromAttention(
                                dim=dim,
                                dim_head=dim_head,
                                num_heads=num_heads,
                                num_landmarks=num_landmarks,
                                inverse_iter=inverse_iter,
                                residual=residual,
                                residual_conv_kernel=residual_conv_kernel,
                                dropout_rate=dropout_rate,
                            ),
                        ),
                        PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Run ``x`` through every layer.

        Args:
            x: Input features; presumably shaped (batch, seq, dim) -- TODO
                confirm against ``NystromAttention``.
            mask: Optional mask passed as ``mask=`` to each attention
                sub-block; the feed-forward sub-blocks do not receive it.

        Returns:
            Tensor of the same shape as ``x`` (each sub-block's output is
            added to its input).
        """
        for attn, ff in self.layers:
            x = attn(x, mask=mask) + x
            x = ff(x) + x
        return x