Diffstat (limited to 'notebooks')
-rw-r--r--  notebooks/00-scratch-pad.ipynb  436
1 file changed, 430 insertions(+), 6 deletions(-)
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index d50fd59..b6ec2c8 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -25,6 +25,410 @@
},
{
"cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "efficient_transformer = partial(Nystromer,\n",
+ " dim = 512,\n",
+ " depth = 12,\n",
+ " num_heads = 8,\n",
+ " num_landmarks = 256\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "functools.partial"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(efficient_transformer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "efficient_transformer = efficient_transformer(num_landmarks=256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Nystromer(\n",
+ " (layers): ModuleList(\n",
+ " (0): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (1): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (2): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (3): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (4): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (5): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (6): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (7): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (8): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (9): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (10): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (11): ModuleList(\n",
+ " (0): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): NystromAttention(\n",
+ " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
+ " (fc_out): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (1): PreNorm(\n",
+ " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (fn): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "efficient_transformer()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "8"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(list(filter(lambda x: x == \"a\", (\"a\", \"c\") * 8)))"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
@@ -486,7 +890,7 @@
{
"data": {
"text/plain": [
- "144"
+ "18"
]
},
"execution_count": 30,
@@ -495,27 +899,47 @@
}
],
"source": [
- "576 // 4"
+ "576 // 32"
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "160"
+ "20"
]
},
- "execution_count": 29,
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "640 // 32"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "360"
+ ]
+ },
+ "execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "640 // 4"
+ "18 * 20"
]
},
{