Diffstat (limited to 'text_recognizer/networks/transformer/nystromer/nystromer.py')
-rw-r--r--  text_recognizer/networks/transformer/nystromer/nystromer.py  61
1 file changed, 61 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py
new file mode 100644
index 0000000..0283d69
--- /dev/null
+++ b/text_recognizer/networks/transformer/nystromer/nystromer.py
@@ -0,0 +1,61 @@
+"""Nyströmer encoder.
+
+Stolen from:
+ https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py
+
+"""
+from typing import Optional
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.transformer.mlp import FeedForward
+from text_recognizer.networks.transformer.norm import PreNorm
+from text_recognizer.networks.transformer.nystromer.attention import NystromAttention
+
+
+class Nystromer(nn.Module):
+    """Nyströmer encoder: a stack of pre-norm Nyström attention and feedforward blocks."""
+
+    def __init__(
+        self,
+        *,
+        dim: int,
+        depth: int,
+        dim_head: int = 64,
+        num_heads: int = 8,
+        num_landmarks: int = 256,
+        inverse_iter: int = 6,
+        residual: bool = True,
+        residual_conv_kernel: int = 33,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+        # Each layer pairs a pre-norm Nyström attention block with a pre-norm feedforward block.
+        self.layers = nn.ModuleList(
+            nn.ModuleList(
+                [
+                    PreNorm(
+                        dim,
+                        NystromAttention(
+                            dim=dim,
+                            dim_head=dim_head,
+                            num_heads=num_heads,
+                            num_landmarks=num_landmarks,
+                            inverse_iter=inverse_iter,
+                            residual=residual,
+                            residual_conv_kernel=residual_conv_kernel,
+                            dropout_rate=dropout_rate,
+                        ),
+                    ),
+                    PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)),
+                ]
+            )
+            for _ in range(depth)
+        )
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        """Apply each attention and feedforward block with a residual connection."""
+        for attn, ff in self.layers:
+            x = attn(x, mask=mask) + x
+            x = ff(x) + x
+        return x
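
For reference, a minimal usage sketch of the encoder added above. The embedding size, depth, sequence length, landmark count, and boolean padding-mask shape are illustrative assumptions following the lucidrains nystrom-attention API this file is adapted from, not values taken from the commit:

import torch

from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer

# Assumed hyperparameters for illustration only.
encoder = Nystromer(dim=256, depth=4, num_heads=8, num_landmarks=64)
x = torch.randn(1, 576, 256)                  # (batch, sequence length, dim)
mask = torch.ones(1, 576, dtype=torch.bool)   # assumed padding mask of shape (batch, sequence length)
out = encoder(x, mask=mask)                   # output keeps the input shape: (1, 576, 256)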