summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-04 23:24:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-04 23:24:00 +0200
commit4defc734b681071e19dd86404abd416d24330b9a (patch)
tree2447a7bc3fada64d1b45ac73346f816f9e90849c
parent53450493e0a13d835fd1d2457c49a9d60bee0e18 (diff)
Bug fix
-rw-r--r--notebooks/00-scratch-pad.ipynb (renamed from notebooks/00-testing-stuff-out.ipynb)51
-rw-r--r--text_recognizer/networks/transformer/nystromer/nystromer.py34
2 files changed, 69 insertions, 16 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-scratch-pad.ipynb
index 12c5145..d50fd59 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -27,6 +27,57 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ModuleList(\n",
+ " (0): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (1): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (2): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (3): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (4): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (5): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (6): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (7): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (8): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (9): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
"outputs": [],
"source": [
"from omegaconf import OmegaConf"
diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py
index 0283d69..7cc889e 100644
--- a/text_recognizer/networks/transformer/nystromer/nystromer.py
+++ b/text_recognizer/networks/transformer/nystromer/nystromer.py
@@ -30,24 +30,26 @@ class Nystromer(nn.Module):
super().__init__()
self.layers = nn.ModuleList(
[
- [
- PreNorm(
- dim,
- NystromAttention(
- dim=dim,
- dim_head=dim_head,
- num_heads=num_heads,
- num_landmarks=num_landmarks,
- inverse_iter=inverse_iter,
- residual=residual,
- residual_conv_kernel=residual_conv_kernel,
- dropout_rate=dropout_rate,
+ nn.ModuleList(
+ [
+ PreNorm(
+ dim,
+ NystromAttention(
+ dim=dim,
+ dim_head=dim_head,
+ num_heads=num_heads,
+ num_landmarks=num_landmarks,
+ inverse_iter=inverse_iter,
+ residual=residual,
+ residual_conv_kernel=residual_conv_kernel,
+ dropout_rate=dropout_rate,
+ ),
),
- ),
- PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)),
- ]
+ PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)),
+ ]
+ )
+ for _ in range(depth)
]
- for _ in range(depth)
)
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: