diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-04 23:24:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-04 23:24:00 +0200 |
commit | 4defc734b681071e19dd86404abd416d24330b9a (patch) | |
tree | 2447a7bc3fada64d1b45ac73346f816f9e90849c /text_recognizer | |
parent | 53450493e0a13d835fd1d2457c49a9d60bee0e18 (diff) |
Bug fix
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/transformer/nystromer/nystromer.py | 34 |
1 files changed, 18 insertions, 16 deletions
diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py index 0283d69..7cc889e 100644 --- a/text_recognizer/networks/transformer/nystromer/nystromer.py +++ b/text_recognizer/networks/transformer/nystromer/nystromer.py @@ -30,24 +30,26 @@ class Nystromer(nn.Module): super().__init__() self.layers = nn.ModuleList( [ - [ - PreNorm( - dim, - NystromAttention( - dim=dim, - dim_head=dim_head, - num_heads=num_heads, - num_landmarks=num_landmarks, - inverse_iter=inverse_iter, - residual=residual, - residual_conv_kernel=residual_conv_kernel, - dropout_rate=dropout_rate, + nn.ModuleList( + [ + PreNorm( + dim, + NystromAttention( + dim=dim, + dim_head=dim_head, + num_heads=num_heads, + num_landmarks=num_landmarks, + inverse_iter=inverse_iter, + residual=residual, + residual_conv_kernel=residual_conv_kernel, + dropout_rate=dropout_rate, + ), ), - ), - PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)), - ] + PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)), + ] + ) + for _ in range(depth) ] - for _ in range(depth) ) def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: |