Diffstat (limited to 'src/notebooks/00-testing-stuff-out.ipynb')
-rw-r--r-- | src/notebooks/00-testing-stuff-out.ipynb | 312 |
1 file changed, 236 insertions, 76 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index b5fdbe0..0e4b298 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -16,6 +16,7 @@ "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", + "from torchsummary import summary\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", @@ -24,73 +25,76 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks import CTCTransformer" + "from text_recognizer.networks import CNN" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ - "model = CTCTransformer(\n", - " num_encoder_layers=2,\n", - " hidden_dim=256,\n", - " vocab_size=56,\n", - " num_heads=8,\n", - " adaptive_pool_dim=[None, 1],\n", - " expansion_dim=2048,\n", - " dropout_rate=0.1,\n", - " max_len=256,\n", - " patch_size=(28, 32),\n", - " stride=(1, 28),\n", - " activation=\"gelu\",\n", - " backbone=\"WideResidualNetwork\",\n", - "backbone_args={\n", - " \"in_channels\": 1,\n", - " \"in_planes\": 64,\n", - " \"num_classes\": 80,\n", - " \"depth\": 10,\n", - " \"width_factor\": 1,\n", - " \"dropout_rate\": 0.1,\n", - " \"num_layers\": 4,\n", - " \"num_stages\": [64, 128, 256, 256],\n", - " \"activation\": \"elu\",\n", - " \"use_decoder\": False,\n", - "},\n", - " )" + "cnn = CNN()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 79, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "i = nn.Sequential(nn.Conv2d(1,1,1,1))" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 81, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (1): Sequential(\n", + " (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + ")" + ] + }, + "execution_count": 81, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.Sequential(i,i)" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 128, 1, 59])" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "backbone: WideResidualNetwork\n", - " backbone_args:\n", - " in_channels: 1\n", - " in_planes: 64\n", - " num_classes: 80\n", - " depth: 10\n", - " width_factor: 1\n", - " dropout_rate: 0.1\n", - " num_layers: 4 \n", - " num_stages: [64, 128, 256, 256]\n", - " activation: elu\n", - " use_decoder: false\n", - " n" + "cnn(t).shape" ] }, { @@ -99,80 +103,236 @@ "metadata": {}, "outputs": [], "source": [ + "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ "t = torch.randn(2, 1, 28, 952)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "x, l = vqvae(t)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([56, 952])" + "29.5" ] }, - "execution_count": 3, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "t.view(-1, 952).shape" + "5 * 59 / 10" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([119, 2, 56])" + "torch.Size([2, 1, 28, 952])" ] }, - "execution_count": 14, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model(t).shape" + "x.shape" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 26, "metadata": {}, "outputs": [ { - "ename": "RuntimeError", - "evalue": "Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]", - "output_type": "error", - "traceback": [ - "\u001b[0;31m----------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n", - "\u001b[0;32m~/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/networks/ctc_transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, trg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext_representation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"b t y -> t b y\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n", - "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n", - 
"\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1691\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1692\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1693\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (238x128 and 256x56)", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m<ipython-input-8-85c5209ae40a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m952\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mexecuted_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mlayer\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0msummary_list\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecuted\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m raise RuntimeError(\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\"Failed to run torchsummary. See above stack traces for more details. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\"Executed layers up to: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexecuted_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]" + "name": "stdout", + "output_type": "stream", + "text": [ + "===============================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "===============================================================================================\n", + "├─Encoder: 1-1 [-1, 32, 5, 59] --\n", + "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n", + "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n", + "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n", + "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n", + "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n", + "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n", + "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n", + "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n", + "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n", + "├─Decoder: 1-2 [-1, 1, 28, 952] --\n", + "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n", + "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n", + "| └─Sequential: 2 [] --\n", + "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n", + "| └─Sequential: 2 [] --\n", + "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n", + "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n", + "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n", + "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n", + "| └─Sequential: 2 [] --\n", + "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n", + "| └─Sequential: 2 [] --\n", + "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n", + "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n", + "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n", + "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n", + "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n", + "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n", + "===============================================================================================\n", + "Total params: 4,343,905\n", + "Trainable params: 4,343,905\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 1.76\n", + "===============================================================================================\n", + "Input size (MB): 0.10\n", + "Forward/backward pass size (MB): 9.32\n", + "Params size 
(MB): 16.57\n", + "Estimated Total Size (MB): 26.00\n", + "===============================================================================================\n" ] + }, + { + "data": { + "text/plain": [ + "===============================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "===============================================================================================\n", + "├─Encoder: 1-1 [-1, 32, 5, 59] --\n", + "| └─Sequential: 2-1 [-1, 32, 5, 59] --\n", + "| | └─Sequential: 3-1 [-1, 32, 14, 476] 544\n", + "| | └─Sequential: 3-2 [-1, 128, 7, 238] 65,664\n", + "| | └─Sequential: 3-3 [-1, 128, 6, 119] 262,272\n", + "| | └─Sequential: 3-4 [-1, 256, 5, 59] 524,544\n", + "| | └─_ResidualBlock: 3-5 [-1, 256, 5, 59] 655,360\n", + "| | └─_ResidualBlock: 3-6 [-1, 256, 5, 59] 655,360\n", + "| | └─Conv2d: 3-7 [-1, 32, 5, 59] 8,224\n", + "| └─VectorQuantizer: 2-2 [-1, 32, 5, 59] --\n", + "├─Decoder: 1-2 [-1, 1, 28, 952] --\n", + "| └─Sequential: 2-3 [-1, 1, 28, 952] --\n", + "| └─Sequential: 2-4 [-1, 256, 5, 59] --\n", + "| └─Sequential: 2 [] --\n", + "| | └─Sequential: 3-8 [-1, 256, 5, 59] (recursive)\n", + "| └─Sequential: 2 [] --\n", + "| | └─Conv2d: 3-9 [-1, 256, 5, 59] 8,448\n", + "| | └─_ResidualBlock: 3-10 [-1, 256, 5, 59] 655,360\n", + "| | └─_ResidualBlock: 3-11 [-1, 256, 5, 59] 655,360\n", + "| └─Sequential: 2-5 [-1, 1, 28, 952] --\n", + "| └─Sequential: 2 [] --\n", + "| | └─Sequential: 3-12 [-1, 1, 28, 952] (recursive)\n", + "| └─Sequential: 2 [] --\n", + "| | └─Sequential: 3-13 [-1, 128, 6, 118] 524,416\n", + "| | └─Upsample: 3-14 [-1, 128, 6, 119] --\n", + "| | └─Sequential: 3-15 [-1, 128, 7, 238] 262,272\n", + "| | └─Upsample: 3-16 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-17 [-1, 32, 14, 476] 65,568\n", + "| | └─ConvTranspose2d: 3-18 [-1, 1, 28, 952] 513\n", + "| | └─Tanh: 3-19 [-1, 1, 28, 952] --\n", + "===============================================================================================\n", + "Total params: 4,343,905\n", + "Trainable params: 4,343,905\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 1.76\n", + "===============================================================================================\n", + "Input size (MB): 0.10\n", + "Forward/backward pass size (MB): 9.32\n", + "Params size (MB): 16.57\n", + "Estimated Total Size (MB): 26.00\n", + "===============================================================================================" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "up = nn.Upsample([4, 59])" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 32, 4, 59])" + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "up(tt).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 32, 1, 59])" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "summary(model, (1, 28, 952), device=\"cpu\", depth=3)" + "tt.shape" ] }, { |
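For reference, the VQVAE experiment that this diff adds to the notebook reduces to the short sketch below. All values are copied from the new cells; the positional VQVAE constructor arguments and the meaning of the second return value are assumptions inferred from the notebook output, not from documented API.

    import torch
    from torchsummary import summary
    from text_recognizer.networks.vqvae import VQVAE

    # Channel dims, kernel sizes, strides, codebook size and embedding shapes are
    # taken verbatim from the notebook cell; the constructor signature itself is
    # an assumption based on that call.
    vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]],
                  2, 32, 256, [[6, 119], [7, 238]])

    t = torch.randn(2, 1, 28, 952)   # batch of two 1-channel 28x952 line images
    x, l = vqvae(t)                  # x: reconstruction, l: presumably the VQ loss term
    print(x.shape)                   # torch.Size([2, 1, 28, 952]) in the notebook output

    # Layer-by-layer overview, as run in the notebook (~4.34M parameters reported).
    summary(vqvae, (1, 28, 952), device="cpu", depth=3)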