summaryrefslogtreecommitdiff
path: root/notebooks
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks')
-rw-r--r--notebooks/00-scratch-pad.ipynb268
1 files changed, 231 insertions, 37 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 0a68168..3c44f2b 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -2,9 +2,18 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 6,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -25,7 +34,216 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.transformer.layers import Decoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decoder = Decoder(dim=256, depth=4, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Decoder(\n",
+ " (layers): ModuleList(\n",
+ " (0): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (1): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (2): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=256, out_features=2048, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=1024, out_features=256, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (3): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (4): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (5): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=256, out_features=2048, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=1024, out_features=256, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (6): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (7): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (8): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=256, out_features=2048, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=1024, out_features=256, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (9): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (10): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=49152, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=16384, out_features=256, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (11): ModuleList(\n",
+ " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=256, out_features=2048, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=1024, out_features=256, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "decoder.cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +252,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -43,21 +261,21 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"efficient_transformer = Nystromer(\n",
" dim = 128,\n",
- " depth = 8,\n",
- " num_heads = 6,\n",
+ " depth = 4,\n",
+ " num_heads = 8,\n",
" num_landmarks = 128\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -66,7 +284,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -80,7 +298,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -89,31 +307,9 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "ename": "RuntimeError",
- "evalue": "CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 7.79 GiB total capacity; 6.44 GiB already allocated; 10.31 MiB free; 6.50 GiB reserved in total by PyTorch)",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-8-996bed2e6057>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/vit.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_embedding\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransformer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 46\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/nystromer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mattn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mff\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mattn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mff\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/norm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\"\"\"Norm tensor.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/attention.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask, return_attn)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscale\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nystrom_attention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_attn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;31m# Add depth-wise convolutional residual of values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/attention.py\u001b[0m in \u001b[0;36m_nystrom_attention\u001b[0;34m(self, q, k, v, mask, n, m, return_attn)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;31m# Compute attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mattn1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msim1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msim2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msim3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m \u001b[0mattn2_inv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmoore_penrose_inverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattn2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minverse_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mattn1\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mattn2_inv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mattn3\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/attention.py\u001b[0m in \u001b[0;36mmoore_penrose_inverse\u001b[0;34m(x, iters)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miters\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mxz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.25\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m13\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mI\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mxz\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m15\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mI\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mxz\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m7\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mI\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mxz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 7.79 GiB total capacity; 6.44 GiB already allocated; 10.31 MiB free; 6.50 GiB reserved in total by PyTorch)"
- ]
- }
- ],
+ "outputs": [],
"source": [
"v(t).shape"
]
@@ -337,9 +533,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [],
"source": [
"en(datum).shape"