summaryrefslogtreecommitdiff
path: root/notebooks/00-scratch-pad.ipynb
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
commitc9c60678673e19ad3367339eb8e7a093e5a98474 (patch)
treeb787a7fbb535c2ee44f935720d75034cc24ffd30 /notebooks/00-scratch-pad.ipynb
parenta2a3133ed5da283888efbdb9924d0e3733c274c8 (diff)
Reformatting of positional encodings and ViT working
Diffstat (limited to 'notebooks/00-scratch-pad.ipynb')
-rw-r--r--notebooks/00-scratch-pad.ipynb808
1 files changed, 124 insertions, 684 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 0a5e2f3..0a68168 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -43,216 +43,138 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "efficient_transformer = partial(Nystromer,\n",
- " dim = 512,\n",
- " depth = 12,\n",
- " num_heads = 8,\n",
- " num_landmarks = 256\n",
+ "efficient_transformer = Nystromer(\n",
+ " dim = 128,\n",
+ " depth = 8,\n",
+ " num_heads = 6,\n",
+ " num_landmarks = 128\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.encoders.efficientnet import EfficientNet"
+ "from text_recognizer.networks.transformer.vit import ViT"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
- "en = EfficientNet()"
+ "v = ViT(\n",
+ " dim = 128,\n",
+ " image_size = (576, 640),\n",
+ " patch_size = (32, 32),\n",
+ " transformer = efficient_transformer\n",
+ ").cuda()"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(32, 1, 576, 640).cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 256, 18, 20] --\n",
- "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n",
- "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n",
- "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n",
- "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n",
- "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n",
- "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n",
- "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n",
- "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n",
- "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n",
- "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n",
- "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n",
- "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n",
- "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n",
- "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n",
- "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n",
- "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n",
- "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n",
- "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n",
- "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n",
- "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n",
- "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n",
- "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n",
- "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n",
- "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n",
- "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n",
- "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n",
- "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n",
- "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n",
- "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n",
- "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n",
- "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n",
- "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n",
- "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n",
- "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n",
- "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n",
- "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n",
- "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n",
- "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n",
- "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n",
- "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n",
- "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n",
- "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n",
- "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n",
- "==========================================================================================\n",
- "Total params: 13,704,252\n",
- "Trainable params: 13,704,252\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (G): 1.23\n",
- "==========================================================================================\n",
- "Input size (MB): 1.41\n",
- "Forward/backward pass size (MB): 111.45\n",
- "Params size (MB): 52.28\n",
- "Estimated Total Size (MB): 165.13\n",
- "==========================================================================================\n"
+ "ename": "RuntimeError",
+ "evalue": "CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 7.79 GiB total capacity; 6.44 GiB already allocated; 10.31 MiB free; 6.50 GiB reserved in total by PyTorch)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-8-996bed2e6057>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/vit.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_embedding\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransformer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 46\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/nystromer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mattn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mff\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mattn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mff\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/norm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\"\"\"Norm tensor.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/attention.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask, return_attn)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscale\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nystrom_attention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_attn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;31m# Add depth-wise convolutional residual of values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/attention.py\u001b[0m in \u001b[0;36m_nystrom_attention\u001b[0;34m(self, q, k, v, mask, n, m, return_attn)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;31m# Compute attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mattn1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msim1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msim2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msim3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m \u001b[0mattn2_inv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmoore_penrose_inverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattn2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minverse_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mattn1\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mattn2_inv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mattn3\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/nystromer/attention.py\u001b[0m in \u001b[0;36mmoore_penrose_inverse\u001b[0;34m(x, iters)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miters\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mxz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.25\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m13\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mI\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mxz\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m15\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mI\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mxz\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m7\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mI\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mxz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 7.79 GiB total capacity; 6.44 GiB already allocated; 10.31 MiB free; 6.50 GiB reserved in total by PyTorch)"
]
- },
- {
- "data": {
- "text/plain": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 256, 18, 20] --\n",
- "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n",
- "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n",
- "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n",
- "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n",
- "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n",
- "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n",
- "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n",
- "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n",
- "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n",
- "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n",
- "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n",
- "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n",
- "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n",
- "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n",
- "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n",
- "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n",
- "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n",
- "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n",
- "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n",
- "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n",
- "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n",
- "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n",
- "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n",
- "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n",
- "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n",
- "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n",
- "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n",
- "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n",
- "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n",
- "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n",
- "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n",
- "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n",
- "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n",
- "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n",
- "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n",
- "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n",
- "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n",
- "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n",
- "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n",
- "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n",
- "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n",
- "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n",
- "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n",
- "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n",
- "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n",
- "==========================================================================================\n",
- "Total params: 13,704,252\n",
- "Trainable params: 13,704,252\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (G): 1.23\n",
- "==========================================================================================\n",
- "Input size (MB): 1.41\n",
- "Forward/backward pass size (MB): 111.45\n",
- "Params size (MB): 52.28\n",
- "Estimated Total Size (MB): 165.13\n",
- "=========================================================================================="
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
+ "v(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.encoders.efficientnet import EfficientNet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "en = EfficientNet()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "(576, 640) // (8, 8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "(576 // 32) ** 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"summary(en, (1, 576, 640))"
]
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "functools.partial"
- ]
- },
- "execution_count": 28,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"type(efficient_transformer)"
]
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -261,330 +183,16 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Nystromer(\n",
- " (layers): ModuleList(\n",
- " (0): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (1): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (2): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (3): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (4): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (5): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (6): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (7): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (8): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (9): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (10): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " (11): ModuleList(\n",
- " (0): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): NystromAttention(\n",
- " (qkv_fn): Linear(in_features=512, out_features=1536, bias=False)\n",
- " (fc_out): Sequential(\n",
- " (0): Linear(in_features=512, out_features=512, bias=True)\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (residual): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)\n",
- " )\n",
- " )\n",
- " (1): PreNorm(\n",
- " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
- " (fn): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=512, out_features=4096, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 29,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"efficient_transformer()"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -692,7 +300,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -701,7 +309,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -739,7 +347,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -748,7 +356,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -757,107 +365,16 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "seed: 4711\n",
- "network:\n",
- " desc: Configuration of the PyTorch neural network.\n",
- " type: CNNTransformer\n",
- " args:\n",
- " encoder:\n",
- " type: EfficientNet\n",
- " args: null\n",
- " num_decoder_layers: 4\n",
- " hidden_dim: 256\n",
- " num_heads: 4\n",
- " expansion_dim: 1024\n",
- " dropout_rate: 0.1\n",
- " transformer_activation: glu\n",
- "model:\n",
- " desc: Configuration of the PyTorch Lightning model.\n",
- " type: LitTransformerModel\n",
- " args:\n",
- " optimizer:\n",
- " type: MADGRAD\n",
- " args:\n",
- " lr: 0.001\n",
- " momentum: 0.9\n",
- " weight_decay: 0\n",
- " eps: 1.0e-06\n",
- " lr_scheduler:\n",
- " type: OneCycleLR\n",
- " args:\n",
- " interval: step\n",
- " max_lr: 0.001\n",
- " three_phase: true\n",
- " epochs: 512\n",
- " steps_per_epoch: 1246\n",
- " criterion:\n",
- " type: CrossEntropyLoss\n",
- " args:\n",
- " weight: None\n",
- " ignore_index: -100\n",
- " reduction: mean\n",
- " monitor: val_loss\n",
- " mapping: sentence_piece\n",
- "data:\n",
- " desc: Configuration of the training/test data.\n",
- " type: IAMExtendedParagraphs\n",
- " args:\n",
- " batch_size: 16\n",
- " num_workers: 12\n",
- " train_fraction: 0.8\n",
- " augment: true\n",
- "callbacks:\n",
- "- type: ModelCheckpoint\n",
- " args:\n",
- " monitor: val_loss\n",
- " mode: min\n",
- " save_last: true\n",
- "- type: StochasticWeightAveraging\n",
- " args:\n",
- " swa_epoch_start: 0.8\n",
- " swa_lrs: 0.05\n",
- " annealing_epochs: 10\n",
- " annealing_strategy: cos\n",
- " device: null\n",
- "- type: LearningRateMonitor\n",
- " args:\n",
- " logging_interval: step\n",
- "- type: EarlyStopping\n",
- " args:\n",
- " monitor: val_loss\n",
- " mode: min\n",
- " patience: 10\n",
- "trainer:\n",
- " desc: Configuration of the PyTorch Lightning Trainer.\n",
- " args:\n",
- " stochastic_weight_avg: true\n",
- " auto_scale_batch_size: binsearch\n",
- " gradient_clip_val: 0\n",
- " fast_dev_run: false\n",
- " gpus: 1\n",
- " precision: 16\n",
- " max_epochs: 512\n",
- " terminate_on_nan: true\n",
- " weights_summary: true\n",
- "load_checkpoint: null\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"print(OmegaConf.to_yaml(conf))"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -866,7 +383,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -893,20 +410,9 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 682, 1004])"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"t(datum, trg).shape"
]
@@ -920,7 +426,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -930,7 +436,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -939,47 +445,25 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([16, 128])"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"x().shape"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([16, 128])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"torch.ones((b, n), device=device).bool().shape"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -988,67 +472,34 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "18"
- ]
- },
- "execution_count": 30,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"576 // 32"
]
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "20"
- ]
- },
- "execution_count": 31,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"640 // 32"
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "360"
- ]
- },
- "execution_count": 32,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"18 * 20"
]
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1057,7 +508,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1066,7 +517,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1076,20 +527,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 1440, 256])"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"p.shape"
]