diff options
-rw-r--r-- | notebooks/00-scratch-pad.ipynb (renamed from notebooks/00-testing-stuff-out.ipynb) | 51 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/nystromer/nystromer.py | 34 |
2 files changed, 69 insertions, 16 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-scratch-pad.ipynb index 12c5145..d50fd59 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -27,6 +27,57 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ModuleList(\n", + " (0): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (1): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (2): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (3): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (4): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (5): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (6): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (7): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (8): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (9): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf" diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py index 0283d69..7cc889e 100644 --- a/text_recognizer/networks/transformer/nystromer/nystromer.py +++ b/text_recognizer/networks/transformer/nystromer/nystromer.py @@ -30,24 +30,26 @@ class Nystromer(nn.Module): super().__init__() self.layers = nn.ModuleList( [ - [ - PreNorm( - dim, - NystromAttention( - dim=dim, - dim_head=dim_head, - num_heads=num_heads, - num_landmarks=num_landmarks, - inverse_iter=inverse_iter, - residual=residual, - residual_conv_kernel=residual_conv_kernel, - dropout_rate=dropout_rate, + nn.ModuleList( + [ + PreNorm( + dim, + NystromAttention( + dim=dim, + dim_head=dim_head, + num_heads=num_heads, + num_landmarks=num_landmarks, + inverse_iter=inverse_iter, + residual=residual, + residual_conv_kernel=residual_conv_kernel, + dropout_rate=dropout_rate, + ), ), - ), - PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)), - ] + PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)), + ] + ) + for _ in range(depth) ] - for _ in range(depth) ) def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: |