summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/transformer/nystromer/nystromer.py34
1 files changed, 18 insertions, 16 deletions
diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py
index 0283d69..7cc889e 100644
--- a/text_recognizer/networks/transformer/nystromer/nystromer.py
+++ b/text_recognizer/networks/transformer/nystromer/nystromer.py
@@ -30,24 +30,26 @@ class Nystromer(nn.Module):
super().__init__()
self.layers = nn.ModuleList(
[
- [
- PreNorm(
- dim,
- NystromAttention(
- dim=dim,
- dim_head=dim_head,
- num_heads=num_heads,
- num_landmarks=num_landmarks,
- inverse_iter=inverse_iter,
- residual=residual,
- residual_conv_kernel=residual_conv_kernel,
- dropout_rate=dropout_rate,
+ nn.ModuleList(
+ [
+ PreNorm(
+ dim,
+ NystromAttention(
+ dim=dim,
+ dim_head=dim_head,
+ num_heads=num_heads,
+ num_landmarks=num_landmarks,
+ inverse_iter=inverse_iter,
+ residual=residual,
+ residual_conv_kernel=residual_conv_kernel,
+ dropout_rate=dropout_rate,
+ ),
),
- ),
- PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)),
- ]
+ PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)),
+ ]
+ )
+ for _ in range(depth)
]
- for _ in range(depth)
)
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: