path: root/src/notebooks/00-testing-stuff-out.ipynb
Diffstat (limited to 'src/notebooks/00-testing-stuff-out.ipynb')
-rw-r--r--  src/notebooks/00-testing-stuff-out.ipynb  669
1 file changed, 523 insertions(+), 146 deletions(-)
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index 6f01dfb..dd02098 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 1,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -22,6 +13,7 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from PIL import Image\n",
+ "import torch.nn.functional as F\n",
"import torch\n",
"from torch import nn\n",
"from importlib.util import find_spec\n",
@@ -32,74 +24,386 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder"
+ "from text_recognizer.networks import CTCTransformer"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks import WideResidualNetwork"
+ "model = CTCTransformer(\n",
+ " num_encoder_layers=2,\n",
+ " hidden_dim=256,\n",
+ " vocab_size=56,\n",
+ " num_heads=8,\n",
+ " adaptive_pool_dim=[None, 1],\n",
+ " expansion_dim=2048,\n",
+ " dropout_rate=0.1,\n",
+ " max_len=256,\n",
+ " patch_size=(28, 32),\n",
+ " stride=(1, 28),\n",
+ " activation=\"gelu\",\n",
+ " backbone=\"WideResidualNetwork\",\n",
+ "backbone_args={\n",
+ " \"in_channels\": 1,\n",
+ " \"in_planes\": 64,\n",
+ " \"num_classes\": 80,\n",
+ " \"depth\": 10,\n",
+ " \"width_factor\": 1,\n",
+ " \"dropout_rate\": 0.1,\n",
+ " \"num_layers\": 4,\n",
+ " \"num_stages\": [64, 128, 256, 256],\n",
+ " \"activation\": \"elu\",\n",
+ " \"use_decoder\": False,\n",
+ "},\n",
+ " )"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "from pathlib import Path"
+ "backbone: WideResidualNetwork\n",
+ " backbone_args:\n",
+ " in_channels: 1\n",
+ " in_planes: 64\n",
+ " num_classes: 80\n",
+ " depth: 10\n",
+ " width_factor: 1\n",
+ " dropout_rate: 0.1\n",
+ " num_layers: 4 \n",
+ " num_stages: [64, 128, 256, 256]\n",
+ " activation: elu\n",
+ " use_decoder: false\n",
+ " n"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(2, 1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "True"
+ "torch.Size([119, 2, 56])"
]
},
- "execution_count": 5,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_081300/model/best.pt\").exists()"
+ "model(t).shape"
]
},
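+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged sketch: the (T, B, C) = (119, 2, 56) output above matches the\n",
+ "# log-prob layout F.ctc_loss/nn.CTCLoss expects. The targets, lengths,\n",
+ "# and blank index below are made-up placeholders, not the real ones.\n",
+ "log_probs = F.log_softmax(model(t), dim=-1)\n",
+ "targets = torch.randint(1, 56, (2, 50))\n",
+ "input_lengths = torch.full((2,), 119, dtype=torch.long)\n",
+ "target_lengths = torch.full((2,), 50, dtype=torch.long)\n",
+ "F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)"
+ ]
+ },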
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "RuntimeError",
+ "evalue": "Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m----------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
+ "\u001b[0;32m~/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/networks/ctc_transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, trg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext_representation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"b t y -> t b y\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1691\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1692\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1693\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (238x128 and 256x56)",
+ "\nThe above exception was the direct cause of the following exception:\n",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-8-85c5209ae40a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m952\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/.pyenv/versions/3.8.2/envs/text-recognizer/lib/python3.8/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mexecuted_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mlayer\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msummary_list\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecuted\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m raise RuntimeError(\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\"Failed to run torchsummary. See above stack traces for more details. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\"Executed layers up to: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexecuted_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [WideResidualNetwork: 1-1, Sequential: 2-1, Conv2d: 3-1, Sequential: 3-2, WideBlock: 4-1, Sequential: 3-3, WideBlock: 4-2, Sequential: 3-4, WideBlock: 4-3, Sequential: 3-5, WideBlock: 4-4, AdaptiveAvgPool2d: 1-2, Encoder: 1-3, EncoderLayer: 3-6, MultiHeadAttention: 4-5, _IntraLayerConnection: 4-6, _ConvolutionalLayer: 4-7, _IntraLayerConnection: 4-8, EncoderLayer: 3-7, MultiHeadAttention: 4-9, _IntraLayerConnection: 4-10, _ConvolutionalLayer: 4-11, _IntraLayerConnection: 4-12, LayerNorm: 2-2, Linear: 2-3, GLU: 2-4]"
+ ]
+ }
+ ],
+ "source": [
+ "summary(model, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class GEGLU(nn.Module):\n",
+ " def __init__(self, dim_in, dim_out):\n",
+ " super().__init__()\n",
+ " self.proj = nn.Linear(dim_in, dim_out * 2)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x, gate = self.proj(x).chunk(2, dim = -1)\n",
+ " return x * F.gelu(gate)"
+ ]
+ },
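+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For comparison (a sanity check only, not part of the model): nn.GLU is\n",
+ "# the sigmoid-gated analogue of GEGLU. It halves its input's last dim, so\n",
+ "# Linear(256, 2 * 2048) followed by GLU matches GEGLU(256, 2048)'s shapes.\n",
+ "glu_ff = nn.Sequential(nn.Linear(256, 2 * 2048), nn.GLU(dim=-1))\n",
+ "glu_ff(torch.randn(1, 30, 256)).shape"
+ ]
+ },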
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e = GEGLU(256, 2048)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 30, 2048])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "e(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "emb = nn.Embedding(56, 256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " e = emb(torch.Tensor([55]).long())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import repeat"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ee = repeat(e, \"() n -> b n\", b=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ModuleAttributeError",
+ "evalue": "'Embedding' object has no attribute 'device'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleAttributeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-58-657f11e4a017>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0memb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-N1c_zsdp-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 778\u001b[0;31m raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 779\u001b[0m type(self).__name__, name))\n\u001b[1;32m 780\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mModuleAttributeError\u001b[0m: 'Embedding' object has no attribute 'device'"
+ ]
+ }
+ ],
+ "source": [
+ "emb.device"
+ ]
+ },
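+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# nn.Module has no .device attribute (hence the error above); the usual\n",
+ "# workaround is to read the device off one of the module's parameters.\n",
+ "next(emb.parameters()).device"
+ ]
+ },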
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
+ " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
+ " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
+ " ...,\n",
+ " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
+ " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005],\n",
+ " [-1.0624, 0.0674, 0.9387, ..., -0.1852, -0.1303, 0.8005]])"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ee"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 256])"
+ ]
+ },
+ "execution_count": 47,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ee.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(16, 10, 256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 10, 256])"
+ ]
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 12, 256])"
+ ]
+ },
+ "execution_count": 57,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.shape"
+ ]
+ },
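+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch of the pattern above as a reusable helper: wrap a batch\n",
+ "# of embedded sequences with a learned special token (index 55 here) on\n",
+ "# both ends, taking (B, T, D) to (B, T + 2, D). Names are illustrative.\n",
+ "def add_boundary_tokens(x, embedding, token_index=55):\n",
+ "    tok = embedding(torch.tensor([token_index], device=x.device))\n",
+ "    tok = repeat(tok, \"() d -> b () d\", b=x.shape[0])\n",
+ "    return torch.cat((tok, x, tok), dim=1)\n",
+ "\n",
+ "add_boundary_tokens(torch.randn(16, 10, 256), emb).shape"
+ ]
+ },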
+ {
+ "cell_type": "code",
+ "execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "False"
+ "torch.Size([1, 256])"
]
},
- "execution_count": 6,
+ "execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_201649/model/best.pt\").exists()"
+ "e.shape"
]
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import WideResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -109,16 +413,17 @@
" in_planes=64,\n",
" depth=10,\n",
" num_layers=4,\n",
- " width_factor=1,\n",
- " dropout_rate= 0.2,\n",
+ " width_factor=2,\n",
+ " num_stages=[64, 128, 256, 256],\n",
+ " dropout_rate= 0.1,\n",
" activation= \"SELU\",\n",
- " use_decoder= True,\n",
+ " use_decoder= False,\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -127,16 +432,16 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
- "backbone = ResidualNetworkEncoder(1, [64, 128, 256], [2, 2, 3])"
+ "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 52,
"metadata": {},
"outputs": [
{
@@ -146,27 +451,31 @@
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 64, 28, 952] --\n",
- "| └─Conv2d: 2-1 [-1, 64, 28, 952] 576\n",
- "| └─BatchNorm2d: 2-2 [-1, 64, 28, 952] 128\n",
- "| └─ReLU: 2-3 [-1, 64, 28, 952] --\n",
- "├─Sequential: 1-2 [-1, 256, 7, 238] --\n",
- "| └─ResidualLayer: 2-4 [-1, 64, 28, 952] --\n",
- "| | └─Sequential: 3-1 [-1, 64, 28, 952] 147,968\n",
- "| └─ResidualLayer: 2-5 [-1, 128, 14, 476] --\n",
- "| | └─Sequential: 3-2 [-1, 128, 14, 476] 525,568\n",
- "| └─ResidualLayer: 2-6 [-1, 256, 7, 238] --\n",
- "| | └─Sequential: 3-3 [-1, 256, 7, 238] 3,280,384\n",
+ "├─Sequential: 1-1 [-1, 64, 12, 474] --\n",
+ "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
+ "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
+ "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n",
+ "├─Sequential: 1-2 [-1, 68, 1, 30] --\n",
+ "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n",
+ "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n",
+ "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n",
+ "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n",
+ "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n",
+ "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n",
+ "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n",
+ "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n",
+ "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n",
+ "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n",
"==========================================================================================\n",
- "Total params: 3,954,624\n",
- "Trainable params: 3,954,624\n",
+ "Total params: 805,910\n",
+ "Trainable params: 805,910\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 31.16\n",
+ "Total mult-adds (M): 21.05\n",
"==========================================================================================\n",
"Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 26.03\n",
- "Params size (MB): 15.09\n",
- "Estimated Total Size (MB): 41.22\n",
+ "Forward/backward pass size (MB): 5.55\n",
+ "Params size (MB): 3.07\n",
+ "Estimated Total Size (MB): 8.73\n",
"==========================================================================================\n"
]
},
@@ -176,31 +485,35 @@
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 64, 28, 952] --\n",
- "| └─Conv2d: 2-1 [-1, 64, 28, 952] 576\n",
- "| └─BatchNorm2d: 2-2 [-1, 64, 28, 952] 128\n",
- "| └─ReLU: 2-3 [-1, 64, 28, 952] --\n",
- "├─Sequential: 1-2 [-1, 256, 7, 238] --\n",
- "| └─ResidualLayer: 2-4 [-1, 64, 28, 952] --\n",
- "| | └─Sequential: 3-1 [-1, 64, 28, 952] 147,968\n",
- "| └─ResidualLayer: 2-5 [-1, 128, 14, 476] --\n",
- "| | └─Sequential: 3-2 [-1, 128, 14, 476] 525,568\n",
- "| └─ResidualLayer: 2-6 [-1, 256, 7, 238] --\n",
- "| | └─Sequential: 3-3 [-1, 256, 7, 238] 3,280,384\n",
+ "├─Sequential: 1-1 [-1, 64, 12, 474] --\n",
+ "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
+ "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
+ "| └─ReLU: 2-3 [-1, 64, 12, 474] --\n",
+ "├─Sequential: 1-2 [-1, 68, 1, 30] --\n",
+ "| └─ResidualLayer: 2-4 [-1, 64, 12, 474] --\n",
+ "| | └─Sequential: 3-1 [-1, 64, 12, 474] 147,968\n",
+ "| └─ResidualLayer: 2-5 [-1, 65, 6, 237] --\n",
+ "| | └─Sequential: 3-2 [-1, 65, 6, 237] 156,325\n",
+ "| └─ResidualLayer: 2-6 [-1, 66, 3, 119] --\n",
+ "| | └─Sequential: 3-3 [-1, 66, 3, 119] 161,172\n",
+ "| └─ResidualLayer: 2-7 [-1, 67, 2, 60] --\n",
+ "| | └─Sequential: 3-4 [-1, 67, 2, 60] 166,093\n",
+ "| └─ResidualLayer: 2-8 [-1, 68, 1, 30] --\n",
+ "| | └─Sequential: 3-5 [-1, 68, 1, 30] 171,088\n",
"==========================================================================================\n",
- "Total params: 3,954,624\n",
- "Trainable params: 3,954,624\n",
+ "Total params: 805,910\n",
+ "Trainable params: 805,910\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 31.16\n",
+ "Total mult-adds (M): 21.05\n",
"==========================================================================================\n",
"Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 26.03\n",
- "Params size (MB): 15.09\n",
- "Estimated Total Size (MB): 41.22\n",
+ "Forward/backward pass size (MB): 5.55\n",
+ "Params size (MB): 3.07\n",
+ "Estimated Total Size (MB): 8.73\n",
"=========================================================================================="
]
},
- "execution_count": 20,
+ "execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
@@ -211,7 +524,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -222,7 +535,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -231,29 +544,34 @@
"Sequential(\n",
" (0): SELU(inplace=True)\n",
" (1): Sequential(\n",
- " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): Sequential(\n",
+ " (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): SELU(inplace=True)\n",
+ " (3): MaxPool2d(kernel_size=(2, 4), stride=2, padding=1, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (2): Sequential(\n",
+ " (0): Sequential(\n",
" (0): WideBlock(\n",
" (activation): SELU(inplace=True)\n",
" (blocks): Sequential(\n",
" (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): SELU(inplace=True)\n",
" (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): SELU(inplace=True)\n",
" (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" )\n",
" )\n",
" )\n",
- " (2): Sequential(\n",
+ " (1): Sequential(\n",
" (0): WideBlock(\n",
" (activation): SELU(inplace=True)\n",
" (blocks): Sequential(\n",
" (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): SELU(inplace=True)\n",
" (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): SELU(inplace=True)\n",
" (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
@@ -263,14 +581,14 @@
" )\n",
" )\n",
" )\n",
- " (3): Sequential(\n",
+ " (2): Sequential(\n",
" (0): WideBlock(\n",
" (activation): SELU(inplace=True)\n",
" (blocks): Sequential(\n",
" (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): SELU(inplace=True)\n",
" (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): SELU(inplace=True)\n",
" (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
@@ -280,34 +598,28 @@
" )\n",
" )\n",
" )\n",
- " (4): Sequential(\n",
+ " (3): Sequential(\n",
" (0): WideBlock(\n",
" (activation): SELU(inplace=True)\n",
" (blocks): Sequential(\n",
" (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): SELU(inplace=True)\n",
- " (2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
- " (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
+ " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): SELU(inplace=True)\n",
- " (6): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
+ " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" )\n",
" (shortcut): Sequential(\n",
- " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+ " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
- " (2): Sequential(\n",
- " (0): BatchNorm2d(512, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Reduce('b c h w -> b c', 'mean')\n",
- " (3): Linear(in_features=512, out_features=80, bias=True)\n",
- " )\n",
")"
]
},
- "execution_count": 25,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -318,7 +630,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -328,36 +640,32 @@
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
- "├─SELU: 1-1 [-1, 1, 28, 952] --\n",
- "├─Sequential: 1 [] --\n",
- "| └─SELU: 2-1 [-1, 1, 28, 952] --\n",
- "├─Sequential: 1-2 [-1, 512, 4, 119] --\n",
- "| └─Conv2d: 2-2 [-1, 64, 28, 952] 576\n",
- "| └─Sequential: 2-3 [-1, 64, 28, 952] --\n",
- "| | └─WideBlock: 3-1 [-1, 64, 28, 952] 73,984\n",
- "| └─Sequential: 2-4 [-1, 128, 14, 476] --\n",
- "| | └─WideBlock: 3-2 [-1, 128, 14, 476] 229,760\n",
- "| └─Sequential: 2-5 [-1, 256, 7, 238] --\n",
- "| | └─WideBlock: 3-3 [-1, 256, 7, 238] 918,272\n",
- "| └─Sequential: 2-6 [-1, 512, 4, 119] --\n",
- "| | └─WideBlock: 3-4 [-1, 512, 4, 119] 3,671,552\n",
- "├─Sequential: 1-3 [-1, 80] --\n",
- "| └─BatchNorm2d: 2-7 [-1, 512, 4, 119] 1,024\n",
- "├─SELU: 1-4 [-1, 512, 4, 119] --\n",
+ "├─Sequential: 1-1 [-1, 64, 7, 237] --\n",
+ "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
+ "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
+ "├─SELU: 1-2 [-1, 64, 12, 474] --\n",
"├─Sequential: 1 [] --\n",
- "| └─SELU: 2-8 [-1, 512, 4, 119] --\n",
- "| └─Reduce: 2-9 [-1, 512] --\n",
- "| └─Linear: 2-10 [-1, 80] 41,040\n",
+ "| └─SELU: 2-3 [-1, 64, 12, 474] --\n",
+ "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n",
+ "├─Sequential: 1-3 [-1, 256, 1, 30] --\n",
+ "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n",
+ "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n",
+ "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n",
+ "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n",
+ "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n",
+ "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n",
+ "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n",
+ "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n",
"==========================================================================================\n",
- "Total params: 4,936,208\n",
- "Trainable params: 4,936,208\n",
+ "Total params: 2,471,488\n",
+ "Trainable params: 2,471,488\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 35.01\n",
+ "Total mult-adds (M): 27.71\n",
"==========================================================================================\n",
"Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 14.88\n",
- "Params size (MB): 18.83\n",
- "Estimated Total Size (MB): 33.81\n",
+ "Forward/backward pass size (MB): 5.55\n",
+ "Params size (MB): 9.43\n",
+ "Estimated Total Size (MB): 15.08\n",
"==========================================================================================\n"
]
},
@@ -367,51 +675,47 @@
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
- "├─SELU: 1-1 [-1, 1, 28, 952] --\n",
+ "├─Sequential: 1-1 [-1, 64, 7, 237] --\n",
+ "| └─Conv2d: 2-1 [-1, 64, 12, 474] 3,136\n",
+ "| └─BatchNorm2d: 2-2 [-1, 64, 12, 474] 128\n",
+ "├─SELU: 1-2 [-1, 64, 12, 474] --\n",
"├─Sequential: 1 [] --\n",
- "| └─SELU: 2-1 [-1, 1, 28, 952] --\n",
- "├─Sequential: 1-2 [-1, 512, 4, 119] --\n",
- "| └─Conv2d: 2-2 [-1, 64, 28, 952] 576\n",
- "| └─Sequential: 2-3 [-1, 64, 28, 952] --\n",
- "| | └─WideBlock: 3-1 [-1, 64, 28, 952] 73,984\n",
- "| └─Sequential: 2-4 [-1, 128, 14, 476] --\n",
- "| | └─WideBlock: 3-2 [-1, 128, 14, 476] 229,760\n",
- "| └─Sequential: 2-5 [-1, 256, 7, 238] --\n",
- "| | └─WideBlock: 3-3 [-1, 256, 7, 238] 918,272\n",
- "| └─Sequential: 2-6 [-1, 512, 4, 119] --\n",
- "| | └─WideBlock: 3-4 [-1, 512, 4, 119] 3,671,552\n",
- "├─Sequential: 1-3 [-1, 80] --\n",
- "| └─BatchNorm2d: 2-7 [-1, 512, 4, 119] 1,024\n",
- "├─SELU: 1-4 [-1, 512, 4, 119] --\n",
- "├─Sequential: 1 [] --\n",
- "| └─SELU: 2-8 [-1, 512, 4, 119] --\n",
- "| └─Reduce: 2-9 [-1, 512] --\n",
- "| └─Linear: 2-10 [-1, 80] 41,040\n",
+ "| └─SELU: 2-3 [-1, 64, 12, 474] --\n",
+ "| └─MaxPool2d: 2-4 [-1, 64, 7, 237] --\n",
+ "├─Sequential: 1-3 [-1, 256, 1, 30] --\n",
+ "| └─Sequential: 2-5 [-1, 64, 7, 237] --\n",
+ "| | └─WideBlock: 3-1 [-1, 64, 7, 237] 73,984\n",
+ "| └─Sequential: 2-6 [-1, 128, 4, 119] --\n",
+ "| | └─WideBlock: 3-2 [-1, 128, 4, 119] 229,760\n",
+ "| └─Sequential: 2-7 [-1, 256, 2, 60] --\n",
+ "| | └─WideBlock: 3-3 [-1, 256, 2, 60] 918,272\n",
+ "| └─Sequential: 2-8 [-1, 256, 1, 30] --\n",
+ "| | └─WideBlock: 3-4 [-1, 256, 1, 30] 1,246,208\n",
"==========================================================================================\n",
- "Total params: 4,936,208\n",
- "Trainable params: 4,936,208\n",
+ "Total params: 2,471,488\n",
+ "Trainable params: 2,471,488\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 35.01\n",
+ "Total mult-adds (M): 27.71\n",
"==========================================================================================\n",
"Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 14.88\n",
- "Params size (MB): 18.83\n",
- "Estimated Total Size (MB): 33.81\n",
+ "Forward/backward pass size (MB): 5.55\n",
+ "Params size (MB): 9.43\n",
+ "Estimated Total Size (MB): 15.08\n",
"=========================================================================================="
]
},
- "execution_count": 26,
+ "execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)"
+ "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@@ -1131,16 +1435,89 @@
},
{
"cell_type": "code",
- "execution_count": 59,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
- "pred = torch.Tensor([1,1,1,1,1, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n",
+ "pred = torch.Tensor([1,21,2,45,31, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n",
"target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()"
]
},
{
"cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask = (target != 79)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ True, True, True, True, True, True, False, False, False, False,\n",
+ " True, True, True, True, True, True, False, False, False, False,\n",
+ " True, True, True, True, True, True, False, False, False, False])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 1, 21, 2, 45, 31, 81, 0, 0, 0, 0, 2, 1, 1, 1, 1, 81, 0, 0,\n",
+ " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pred * mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 1, 1, 1, 1, 1, 81, 0, 0, 0, 0, 1, 1, 1, 1, 1, 81, 0, 0,\n",
+ " 0, 0, 1, 1, 1, 1, 1, 81, 0, 0, 0, 0])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "target * mask"
+ ]
+ },
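+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Putting the mask to use (a sketch, assuming 79 is the pad index): count\n",
+ "# agreement only over non-pad positions for a padding-free accuracy.\n",
+ "((pred == target) & mask).sum().float() / mask.sum()"
+ ]
+ },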
+ {
+ "cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],