author    aktersnurra <gustaf.rydholm@gmail.com>  2021-01-24 22:14:17 +0100
committer aktersnurra <gustaf.rydholm@gmail.com>  2021-01-24 22:14:17 +0100
commit    4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree      04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/notebooks/00-testing-stuff-out.ipynb
parent    d691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
Diffstat (limited to 'src/notebooks/00-testing-stuff-out.ipynb')
-rw-r--r--  src/notebooks/00-testing-stuff-out.ipynb  312
1 file changed, 236 insertions(+), 76 deletions(-)
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index b5fdbe0..0e4b298 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -16,6 +16,7 @@
"import torch.nn.functional as F\n",
"import torch\n",
"from torch import nn\n",
+ "from torchsummary import summary\n",
"from importlib.util import find_spec\n",
"if find_spec(\"text_recognizer\") is None:\n",
" import sys\n",
@@ -24,73 +25,76 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks import CTCTransformer"
+ "from text_recognizer.networks import CNN"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
- "model = CTCTransformer(\n",
- " num_encoder_layers=2,\n",
- " hidden_dim=256,\n",
- " vocab_size=56,\n",
- " num_heads=8,\n",
- " adaptive_pool_dim=[None, 1],\n",
- " expansion_dim=2048,\n",
- " dropout_rate=0.1,\n",
- " max_len=256,\n",
- " patch_size=(28, 32),\n",
- " stride=(1, 28),\n",
- " activation=\"gelu\",\n",
- " backbone=\"WideResidualNetwork\",\n",
- "backbone_args={\n",
- " \"in_channels\": 1,\n",
- " \"in_planes\": 64,\n",
- " \"num_classes\": 80,\n",
- " \"depth\": 10,\n",
- " \"width_factor\": 1,\n",
- " \"dropout_rate\": 0.1,\n",
- " \"num_layers\": 4,\n",
- " \"num_stages\": [64, 128, 256, 256],\n",
- " \"activation\": \"elu\",\n",
- " \"use_decoder\": False,\n",
- "},\n",
- " )"
+ "cnn = CNN()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 79,
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "i = nn.Sequential(nn.Conv2d(1,1,1,1))"
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 81,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1): Sequential(\n",
+ " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 81,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nn.Sequential(i,i)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 128, 1, 59])"
+ ]
+ },
+ "execution_count": 64,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "backbone: WideResidualNetwork\n",
- " backbone_args:\n",
- " in_channels: 1\n",
- " in_planes: 64\n",
- " num_classes: 80\n",
- " depth: 10\n",
- " width_factor: 1\n",
- " dropout_rate: 0.1\n",
- " num_layers: 4 \n",
- " num_stages: [64, 128, 256, 256]\n",
- " activation: elu\n",
- " use_decoder: false\n",
- " n"
+ "cnn(t).shape"
]
},
{
@@ -99,80 +103,236 @@
"metadata": {},
"outputs": [],
"source": [
+ "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"t = torch.randn(2, 1, 28, 952)"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x, l = vqvae(t)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([56, 952])"
+ "29.5"
]
},
- "execution_count": 3,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "t.view(-1, 952).shape"
+ "5 * 59 / 10"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([119, 2, 56])"
+ "torch.Size([2, 1, 28, 952])"
]
},
- "execution_count": 14,
+ "execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "model(t).shape"
+ "x.shape"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
- "ename": "RuntimeError",
- "evalue": "Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m----------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
- "\u001b[0;32m~/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/networks/ctc_transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, trg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext_representation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"b t y -> t b y\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1691\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1692\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1693\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (238x128 and 256x56)",
- "\nThe above exception was the direct cause of the following exception:\n",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-8-85c5209ae40a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m952\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mexecuted_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mlayer\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msummary_list\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecuted\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m raise RuntimeError(\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\"Failed to run torchsummary. See above stack traces for more details. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\"Executed layers up to: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexecuted_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Encoder: 1-1 [-1, 32, 5, 59] --\n",
+ "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n",
+ "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n",
+ "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n",
+ "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n",
+ "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n",
+ "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n",
+ "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n",
+ "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n",
+ "├─Decoder: 1-2 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n",
+ "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n",
+ "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n",
+ "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n",
+ "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n",
+ "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n",
+ "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n",
+ "===============================================================================================\n",
+ "Total params: 4,343,905\n",
+ "Trainable params: 4,343,905\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 1.76\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 9.32\n",
+ "Params size (MB): 16.57\n",
+ "Estimated Total Size (MB): 26.00\n",
+ "===============================================================================================\n"
]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Encoder: 1-1 [-1, 32, 5, 59] --\n",
+ "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n",
+ "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n",
+ "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n",
+ "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n",
+ "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n",
+ "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n",
+ "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n",
+ "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n",
+ "├─Decoder: 1-2 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n",
+ "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n",
+ "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n",
+ "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n",
+ "| └─Sequential: 2 [] --\n",
+ "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n",
+ "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n",
+ "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n",
+ "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n",
+ "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n",
+ "===============================================================================================\n",
+ "Total params: 4,343,905\n",
+ "Trainable params: 4,343,905\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 1.76\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 9.32\n",
+ "Params size (MB): 16.57\n",
+ "Estimated Total Size (MB): 26.00\n",
+ "==============================================================================================="
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 94,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "up = nn.Upsample([4, 59])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 107,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 32, 4, 59])"
+ ]
+ },
+ "execution_count": 107,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "up(tt).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 32, 1, 59])"
+ ]
+ },
+ "execution_count": 104,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "summary(model, (1, 28, 952), device=\"cpu\", depth=3)"
+ "tt.shape"
]
},
{