diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 00:36:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 00:36:55 +0200 |
commit | 548f52b35062e258622ea638ed1b132d6759a07a (patch) | |
tree | e9262d0f934ac4f9392f20cb4fcf7be6033e3cb7 /notebooks/00-scratch-pad.ipynb | |
parent | 805d5726c17b83e00dcea0d2608dcd83a91f723d (diff) |
Attention layer soon done
Diffstat (limited to 'notebooks/00-scratch-pad.ipynb')
-rw-r--r-- | notebooks/00-scratch-pad.ipynb | 436 |
1 files changed, 430 insertions, 6 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index d50fd59..b6ec2c8 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -25,6 +25,410 @@ }, { "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "efficient_transformer = partial(Nystromer,\n", + " dim = 512,\n", + " depth = 12,\n", + " num_heads = 8,\n", + " num_landmarks = 256\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "functools.partial" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(efficient_transformer)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "efficient_transformer = efficient_transformer(num_landmarks=256)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Nystromer(\n", + " (layers): ModuleList(\n", + " (0): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (3): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (4): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (5): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (6): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (7): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (8): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (9): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (10): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (11): ModuleList(\n", + " (0): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): NystromAttention(\n", + " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n", + " (fc_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n", + " )\n", + " )\n", + " (1): PreNorm(\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (fn): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=512, out_features=4096, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "efficient_transformer()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(list(filter(lambda x: x == \"a\", (\"a\", \"c\") * 8)))" + ] + }, + { + "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ @@ -486,7 +890,7 @@ { "data": { "text/plain": [ - "144" + "18" ] }, "execution_count": 30, @@ -495,27 +899,47 @@ } ], "source": [ - "576 // 4" + "576 // 32" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "160" + "20" ] }, - "execution_count": 29, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "640 // 32" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "360" + ] + }, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "640 // 4" + "18 * 20" ] }, { |