diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
commit | ff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch) | |
tree | afee959135416fe92cf6df377e84fb0a9e9714a0 /src/notebooks/00-testing-stuff-out.ipynb | |
parent | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff) |
Minor updates.
Diffstat (limited to 'src/notebooks/00-testing-stuff-out.ipynb')
-rw-r--r-- | src/notebooks/00-testing-stuff-out.ipynb | 669 |
1 files changed, 523 insertions, 146 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index 6f01dfb..dd02098 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 21, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -22,6 +13,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", + "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", "from importlib.util import find_spec\n", @@ -32,74 +24,386 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder" + "from text_recognizer.networks import CTCTransformer" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks import WideResidualNetwork" + "model = CTCTransformer(\n", + " num_encoder_layers=2,\n", + " hidden_dim=256,\n", + " vocab_size=56,\n", + " num_heads=8,\n", + " adaptive_pool_dim=[None, 1],\n", + " expansion_dim=2048,\n", + " dropout_rate=0.1,\n", + " max_len=256,\n", + " patch_size=(28, 32),\n", + " stride=(1, 28),\n", + " activation=\"gelu\",\n", + " backbone=\"WideResidualNetwork\",\n", + "backbone_args={\n", + " \"in_channels\": 1,\n", + " \"in_planes\": 64,\n", + " \"num_classes\": 80,\n", + " \"depth\": 10,\n", + " \"width_factor\": 1,\n", + " \"dropout_rate\": 0.1,\n", + " \"num_layers\": 4,\n", + " \"num_stages\": [64, 128, 256, 256],\n", + " \"activation\": \"elu\",\n", + " \"use_decoder\": False,\n", + "},\n", + " )" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path" + "backbone: WideResidualNetwork\n", + " backbone_args:\n", + " in_channels: 1\n", + " in_planes: 64\n", + " num_classes: 80\n", + " depth: 10\n", + " width_factor: 1\n", + " dropout_rate: 0.1\n", + " num_layers: 4 \n", + " num_stages: [64, 128, 256, 256]\n", + " activation: elu\n", + " use_decoder: false\n", + " n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randn(2, 1, 28, 952)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "True" + "torch.Size([119, 2, 56])" ] }, - "execution_count": 5, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_081300/model/best.pt\").exists()" + "model(t).shape" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]", + "output_type": "error", + "traceback": [ + "\u001b[0;31m----------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n", + "\u001b[0;32m~/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/networks/ctc_transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, trg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext_representation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"b t y -> t b y\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1691\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1692\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1693\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (238x128 and 256x56)", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-8-85c5209ae40a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m952\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mexecuted_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mlayer\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msummary_list\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecuted\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m raise RuntimeError(\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\"Failed to run torchsummary. See above stack traces for more details. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\"Executed layers up to: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexecuted_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]" + ] + } + ], + "source": [ + "summary(model, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "class GEGLU(nn.Module):\n", + " def __init__(self, dim_in, dim_out):\n", + " super().__init__()\n", + " self.proj = nn.Linear(dim_in, dim_out * 2)\n", + "\n", + " def forward(self, x):\n", + " x, gate = self.proj(x).chunk(2, dim = -1)\n", + " return x * F.gelu(gate)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "e = GEGLU(256, 2048)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 30, 2048])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e(t).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "emb = nn.Embedding(56, 256)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " e = emb(torch.Tensor([55]).long())" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import repeat" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "ee = repeat(e, \"() n -> b n\", b=16)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleAttributeError", + "evalue": "'Embedding' object has no attribute 'device'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-58-657f11e4a017>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0memb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-N1c_zsdp-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 778\u001b[0;31m raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 779\u001b[0m type(self).__name__, name))\n\u001b[1;32m 780\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleAttributeError\u001b[0m: 'Embedding' object has no attribute 'device'" + ] + } + ], + "source": [ + "emb.device" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", + " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", + " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", + " ...,\n", + " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", + " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n", + " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005]])" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ee" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 256])" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ee.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randn(16, 10, 256)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 10, 256])" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 12, 256])" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "False" + "torch.Size([1, 256])" ] }, - "execution_count": 6, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_201649/model/best.pt\").exists()" + "e.shape" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks import WideResidualNetwork" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -109,16 +413,17 @@ " in_planes=64,\n", " depth=10,\n", " num_layers=4,\n", - " width_factor=1,\n", - " dropout_rate= 0.2,\n", + " width_factor=2,\n", + " num_stages=[64, 128, 256, 256],\n", + " dropout_rate= 0.1,\n", " activation= \"SELU\",\n", - " use_decoder= True,\n", + " use_decoder= False,\n", ")" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -127,16 +432,16 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ - "backbone = ResidualNetworkEncoder(1, [64, 128, 256], [2, 2, 3])" + "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 52, "metadata": {}, "outputs": [ { @@ -146,27 +451,31 @@ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 64, 28, 952] --\n", - "| └─Conv2d: 2-1 [-1, 64, 28, 952] 576\n", - "| └─BatchNorm2d: 2-2 [-1, 64, 28, 952] 128\n", - "| └─ReLU: 2-3 [-1, 64, 28, 952] --\n", - "├─Sequential: 1-2 [-1, 256, 7, 238] --\n", - "| └─ResidualLayer: 2-4 [-1, 64, 28, 952] --\n", - "| | └─Sequential: 3-1 [-1, 64, 28, 952] 147,968\n", - "| └─ResidualLayer: 2-5 [-1, 128, 14, 476] --\n", - "| | └─Sequential: 3-2 [-1, 128, 14, 476] 525,568\n", - "| └─ResidualLayer: 2-6 [-1, 256, 7, 238] --\n", - "| | └─Sequential: 3-3 [-1, 256, 7, 238] 3,280,384\n", + "├─Sequential: 1-1 [-1, 64, 12, 474] --\n", + "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", + "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", + "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n", + "├─Sequential: 1-2 [-1, 68, 1, 30] --\n", + "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n", + "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n", + "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n", + "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n", + "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n", + "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n", + "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n", + "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n", + "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n", + "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n", "==========================================================================================\n", - "Total params: 3,954,624\n", - "Trainable params: 3,954,624\n", + "Total params: 805,910\n", + "Trainable params: 805,910\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 31.16\n", + "Total mult-adds (M): 21.05\n", "==========================================================================================\n", "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 26.03\n", - "Params size (MB): 15.09\n", - "Estimated Total Size (MB): 41.22\n", + "Forward/backward pass size (MB): 5.55\n", + "Params size (MB): 3.07\n", + "Estimated Total Size (MB): 8.73\n", "==========================================================================================\n" ] }, @@ -176,31 +485,35 @@ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 64, 28, 952] --\n", - "| └─Conv2d: 2-1 [-1, 64, 28, 952] 576\n", - "| └─BatchNorm2d: 2-2 [-1, 64, 28, 952] 128\n", - "| └─ReLU: 2-3 [-1, 64, 28, 952] --\n", - "├─Sequential: 1-2 [-1, 256, 7, 238] --\n", - "| └─ResidualLayer: 2-4 [-1, 64, 28, 952] --\n", - "| | └─Sequential: 3-1 [-1, 64, 28, 952] 147,968\n", - "| └─ResidualLayer: 2-5 [-1, 128, 14, 476] --\n", - "| | └─Sequential: 3-2 [-1, 128, 14, 476] 525,568\n", - "| └─ResidualLayer: 2-6 [-1, 256, 7, 238] --\n", - "| | └─Sequential: 3-3 [-1, 256, 7, 238] 3,280,384\n", + "├─Sequential: 1-1 [-1, 64, 12, 474] --\n", + "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", + "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", + "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n", + "├─Sequential: 1-2 [-1, 68, 1, 30] --\n", + "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n", + "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n", + "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n", + "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n", + "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n", + "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n", + "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n", + "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n", + "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n", + "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n", "==========================================================================================\n", - "Total params: 3,954,624\n", - "Trainable params: 3,954,624\n", + "Total params: 805,910\n", + "Trainable params: 805,910\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 31.16\n", + "Total mult-adds (M): 21.05\n", "==========================================================================================\n", "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 26.03\n", - "Params size (MB): 15.09\n", - "Estimated Total Size (MB): 41.22\n", + "Forward/backward pass size (MB): 5.55\n", + "Params size (MB): 3.07\n", + "Estimated Total Size (MB): 8.73\n", "==========================================================================================" ] }, - "execution_count": 20, + "execution_count": 52, "metadata": {}, "output_type": "execute_result" } @@ -211,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -222,7 +535,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -231,29 +544,34 @@ "Sequential(\n", " (0): SELU(inplace=True)\n", " (1): Sequential(\n", - " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): Sequential(\n", + " (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): SELU(inplace=True)\n", + " (3): MaxPool2d(kernel_size=(2, 4), stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " )\n", + " (2): Sequential(\n", + " (0): Sequential(\n", " (0): WideBlock(\n", " (activation): SELU(inplace=True)\n", " (blocks): Sequential(\n", " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): SELU(inplace=True)\n", " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SELU(inplace=True)\n", " (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " )\n", " )\n", " )\n", - " (2): Sequential(\n", + " (1): Sequential(\n", " (0): WideBlock(\n", " (activation): SELU(inplace=True)\n", " (blocks): Sequential(\n", " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): SELU(inplace=True)\n", " (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SELU(inplace=True)\n", " (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", @@ -263,14 +581,14 @@ " )\n", " )\n", " )\n", - " (3): Sequential(\n", + " (2): Sequential(\n", " (0): WideBlock(\n", " (activation): SELU(inplace=True)\n", " (blocks): Sequential(\n", " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): SELU(inplace=True)\n", " (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SELU(inplace=True)\n", " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", @@ -280,34 +598,28 @@ " )\n", " )\n", " )\n", - " (4): Sequential(\n", + " (3): Sequential(\n", " (0): WideBlock(\n", " (activation): SELU(inplace=True)\n", " (blocks): Sequential(\n", " (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): SELU(inplace=True)\n", - " (2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", + " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SELU(inplace=True)\n", - " (6): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " )\n", " (shortcut): Sequential(\n", - " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " )\n", " )\n", " )\n", " )\n", - " (2): Sequential(\n", - " (0): BatchNorm2d(512, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Reduce('b c h w -> b c', 'mean')\n", - " (3): Linear(in_features=512, out_features=80, bias=True)\n", - " )\n", ")" ] }, - "execution_count": 25, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -318,7 +630,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -328,36 +640,32 @@ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", - "├─SELU: 1-1 [-1, 1, 28, 952] --\n", - "├─Sequential: 1 [] --\n", - "| └─SELU: 2-1 [-1, 1, 28, 952] --\n", - "├─Sequential: 1-2 [-1, 512, 4, 119] --\n", - "| └─Conv2d: 2-2 [-1, 64, 28, 952] 576\n", - "| └─Sequential: 2-3 [-1, 64, 28, 952] --\n", - "| | └─WideBlock: 3-1 [-1, 64, 28, 952] 73,984\n", - "| └─Sequential: 2-4 [-1, 128, 14, 476] --\n", - "| | └─WideBlock: 3-2 [-1, 128, 14, 476] 229,760\n", - "| └─Sequential: 2-5 [-1, 256, 7, 238] --\n", - "| | └─WideBlock: 3-3 [-1, 256, 7, 238] 918,272\n", - "| └─Sequential: 2-6 [-1, 512, 4, 119] --\n", - "| | └─WideBlock: 3-4 [-1, 512, 4, 119] 3,671,552\n", - "├─Sequential: 1-3 [-1, 80] --\n", - "| └─BatchNorm2d: 2-7 [-1, 512, 4, 119] 1,024\n", - "├─SELU: 1-4 [-1, 512, 4, 119] --\n", + "├─Sequential: 1-1 [-1, 64, 7, 237] --\n", + "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", + "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", + "├─SELU: 1-2 [-1, 64, 12, 474] --\n", "├─Sequential: 1 [] --\n", - "| └─SELU: 2-8 [-1, 512, 4, 119] --\n", - "| └─Reduce: 2-9 [-1, 512] --\n", - "| └─Linear: 2-10 [-1, 80] 41,040\n", + "| └─SELU: 2-3 [-1, 64, 12, 474] --\n", + "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n", + "├─Sequential: 1-3 [-1, 256, 1, 30] --\n", + "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n", + "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n", + "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n", + "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n", + "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n", + "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n", + "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n", + "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n", "==========================================================================================\n", - "Total params: 4,936,208\n", - "Trainable params: 4,936,208\n", + "Total params: 2,471,488\n", + "Trainable params: 2,471,488\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 35.01\n", + "Total mult-adds (M): 27.71\n", "==========================================================================================\n", "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 14.88\n", - "Params size (MB): 18.83\n", - "Estimated Total Size (MB): 33.81\n", + "Forward/backward pass size (MB): 5.55\n", + "Params size (MB): 9.43\n", + "Estimated Total Size (MB): 15.08\n", "==========================================================================================\n" ] }, @@ -367,51 +675,47 @@ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", - "├─SELU: 1-1 [-1, 1, 28, 952] --\n", + "├─Sequential: 1-1 [-1, 64, 7, 237] --\n", + "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n", + "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n", + "├─SELU: 1-2 [-1, 64, 12, 474] --\n", "├─Sequential: 1 [] --\n", - "| └─SELU: 2-1 [-1, 1, 28, 952] --\n", - "├─Sequential: 1-2 [-1, 512, 4, 119] --\n", - "| └─Conv2d: 2-2 [-1, 64, 28, 952] 576\n", - "| └─Sequential: 2-3 [-1, 64, 28, 952] --\n", - "| | └─WideBlock: 3-1 [-1, 64, 28, 952] 73,984\n", - "| └─Sequential: 2-4 [-1, 128, 14, 476] --\n", - "| | └─WideBlock: 3-2 [-1, 128, 14, 476] 229,760\n", - "| └─Sequential: 2-5 [-1, 256, 7, 238] --\n", - "| | └─WideBlock: 3-3 [-1, 256, 7, 238] 918,272\n", - "| └─Sequential: 2-6 [-1, 512, 4, 119] --\n", - "| | └─WideBlock: 3-4 [-1, 512, 4, 119] 3,671,552\n", - "├─Sequential: 1-3 [-1, 80] --\n", - "| └─BatchNorm2d: 2-7 [-1, 512, 4, 119] 1,024\n", - "├─SELU: 1-4 [-1, 512, 4, 119] --\n", - "├─Sequential: 1 [] --\n", - "| └─SELU: 2-8 [-1, 512, 4, 119] --\n", - "| └─Reduce: 2-9 [-1, 512] --\n", - "| └─Linear: 2-10 [-1, 80] 41,040\n", + "| └─SELU: 2-3 [-1, 64, 12, 474] --\n", + "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n", + "├─Sequential: 1-3 [-1, 256, 1, 30] --\n", + "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n", + "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n", + "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n", + "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n", + "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n", + "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n", + "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n", + "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n", "==========================================================================================\n", - "Total params: 4,936,208\n", - "Trainable params: 4,936,208\n", + "Total params: 2,471,488\n", + "Trainable params: 2,471,488\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 35.01\n", + "Total mult-adds (M): 27.71\n", "==========================================================================================\n", "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 14.88\n", - "Params size (MB): 18.83\n", - "Estimated Total Size (MB): 33.81\n", + "Forward/backward pass size (MB): 5.55\n", + "Params size (MB): 9.43\n", + "Estimated Total Size (MB): 15.08\n", "==========================================================================================" ] }, - "execution_count": 26, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)" + "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -1131,16 +1435,89 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "pred = torch.Tensor([1,1,1,1,1, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n", + "pred = torch.Tensor([1,21,2,45,31, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n", "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()" ] }, { "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "mask = (target != 79)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ True, True, True, True, True, True, False, False, False, False,\n", + " True, True, True, True, True, True, False, False, False, False,\n", + " True, True, True, True, True, True, False, False, False, False])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 1, 21, 2, 45, 31, 81, 0, 0, 0, 0, 2, 1, 1, 1, 1, 81, 0, 0,\n", + " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred * mask" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 1, 1, 1, 1, 1, 81, 0, 0, 0, 0, 1, 1, 1, 1, 1, 81, 0, 0,\n", + " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "target * mask" + ] + }, + { + "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], |