From 8c380f60a4f84f69ab4d2030cce663b4136fa0a7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 25 Oct 2021 22:30:12 +0200 Subject: Remove vq and unet notebooks --- notebooks/04-vq-transformer.ipynb | 253 -------------------- notebooks/04-vqvae.ipynb | 233 ------------------ notebooks/05a-UNet.ipynb | 482 -------------------------------------- 3 files changed, 968 deletions(-) delete mode 100644 notebooks/04-vq-transformer.ipynb delete mode 100644 notebooks/04-vqvae.ipynb delete mode 100644 notebooks/05a-UNet.ipynb diff --git a/notebooks/04-vq-transformer.ipynb b/notebooks/04-vq-transformer.ipynb deleted file mode 100644 index 69d2688..0000000 --- a/notebooks/04-vq-transformer.ipynb +++ /dev/null @@ -1,253 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "7c02ae76-b540-4b16-9492-e9210b3b9249", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n", - "import random\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import numpy as np\n", - "from omegaconf import OmegaConf\n", - "\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ccdb6dde-47e5-429a-88f2-0764fb7e259a", - "metadata": {}, - "outputs": [], - "source": [ - "from hydra import compose, initialize\n", - "from omegaconf import OmegaConf\n", - "from hydra.utils import instantiate" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7", - "metadata": {}, - "outputs": [], - "source": [ - "path = \"../training/conf/experiment/vqgan_htr_char_iam_lines.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e52ecb01-c975-4e55-925d-1182c7aea473", - "metadata": {}, - "outputs": [], - "source": [ - "with open(path, \"rb\") as f:\n", - " cfg = OmegaConf.load(f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f939aa37-7b1d-45cc-885c-323c4540bda1", - "metadata": {}, - "outputs": [], - "source": [ - "cfg" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0", - "metadata": {}, - "outputs": [], - "source": [ - "net = instantiate(cfg.network)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a564ac7a-b67f-4bc1-af36-0fe0a58c1bc9", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aeddcc5c-e48d-4d90-8efa-963011ef40bc", - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randn((16, 1, 16, 64))\n", - "y = torch.randint(0, 56, (16, 89))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f0d78bc-7e0a-4d06-8e38-49b29ad25933", - "metadata": {}, - "outputs": [], - "source": [ - "y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e9f4ee2a-c93f-4461-8d75-40c8c12d9d48", - "metadata": {}, - "outputs": [], - "source": [ - "yy = net(x, y)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7a7493a9-0e1d-46ef-8180-27605e18d082", - "metadata": {}, - "outputs": [], - "source": [ - "yy[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "75bc9695-2afd-455c-a4fb-2e182456ccbd", - "metadata": {}, - "outputs": [], - "source": [ - "z = torch.randn((16, 8, 32))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3df6f9a0-6e66-4f46-a5b7-c0bb71b16b9b", - "metadata": {}, - "outputs": [], - "source": [ - "z, _ = net.encode(x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6d6e9dd1-c56e-4169-8216-bcc84ea980e3", - "metadata": {}, - "outputs": [], - "source": [ - "z.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f1539cb-b9b2-40b7-a843-d7479ddbddd7", - "metadata": {}, - "outputs": [], - "source": [ - "yy = net.decode(z, y[:, :2])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5cdba0a9-da7d-4e33-b209-7f360d1a38e5", - "metadata": {}, - "outputs": [], - "source": [ - "yy.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6da8065f-f93f-4aec-a60e-408712a28c3b", - "metadata": {}, - "outputs": [], - "source": [ - "torch.argmax(yy,dim=-2).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "beabbda7-6a1f-4294-8f01-f9d866ffe088", - "metadata": {}, - "outputs": [], - "source": [ - "yy[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "618b997c-e6a6-4487-b70c-9d260cb556d3", - "metadata": {}, - "outputs": [], - "source": [ - "from torchinfo import summary" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f", - "metadata": {}, - "outputs": [], - "source": [ - "summary(net, (1, 1, 576, 640), device=\"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "62ca0d97-625c-474b-8d6c-d0caba79e198", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/04-vqvae.ipynb b/notebooks/04-vqvae.ipynb deleted file mode 100644 index 1b31671..0000000 --- a/notebooks/04-vqvae.ipynb +++ /dev/null @@ -1,233 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 7, - "id": "136a80f5-10e1-40c4-973a-a7eb7939bb1f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], - "source": [ - "import os\n", - "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n", - "import random\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import numpy as np\n", - "from omegaconf import OmegaConf\n", - "from hydra import compose, initialize\n", - "from omegaconf import OmegaConf\n", - "from hydra.utils import instantiate\n", - "from torchinfo import summary\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "1a0fb9ca-1886-4fd4-839f-dc111a450cfd", - "metadata": {}, - "outputs": [], - "source": [ - "path = \"../training/conf/network/vqvae.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0182a614-5781-44a6-b659-008e7c584fa7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder:\n", - " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n", - " in_channels: 1\n", - " hidden_dim: 32\n", - " channels_multipliers:\n", - " - 1\n", - " - 2\n", - " - 4\n", - " dropout_rate: 0.0\n", - " activation: mish\n", - " use_norm: true\n", - " num_residuals: 4\n", - " residual_channels: 32\n", - "decoder:\n", - " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n", - " out_channels: 1\n", - " hidden_dim: 32\n", - " channels_multipliers:\n", - " - 4\n", - " - 2\n", - " - 1\n", - " dropout_rate: 0.0\n", - " activation: mish\n", - " use_norm: true\n", - " num_residuals: 4\n", - " residual_channels: 32\n", - "_target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n", - "hidden_dim: 128\n", - "embedding_dim: 32\n", - "num_embeddings: 8192\n", - "decay: 0.99\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'vqvae': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information\n", - " warnings.warn(msg, UserWarning)\n" - ] - } - ], - "source": [ - "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"vqvae\")\n", - " print(OmegaConf.to_yaml(cfg))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "a500f94c-7dae-477e-a3fb-2a2d62ee7b72", - "metadata": {}, - "outputs": [], - "source": [ - "net = instantiate(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7f3b3559-5e23-485e-bf57-9405568a1fbf", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "====================================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "====================================================================================================\n", - "VQVAE -- --\n", - "├─Encoder: 1-1 [1, 128, 72, 80] --\n", - "│ └─Sequential: 2-1 [1, 128, 72, 80] --\n", - "│ │ └─Conv2d: 3-1 [1, 32, 576, 640] 320\n", - "│ │ └─Normalize: 3-2 [1, 32, 576, 640] 64\n", - "│ │ └─Mish: 3-3 [1, 32, 576, 640] --\n", - "│ │ └─Mish: 3-4 [1, 32, 576, 640] --\n", - "│ │ └─Mish: 3-5 [1, 32, 576, 640] --\n", - "│ │ └─Conv2d: 3-6 [1, 32, 288, 320] 16,416\n", - "│ │ └─Normalize: 3-7 [1, 32, 288, 320] 64\n", - "│ │ └─Mish: 3-8 [1, 32, 288, 320] --\n", - "│ │ └─Mish: 3-9 [1, 32, 288, 320] --\n", - "│ │ └─Mish: 3-10 [1, 32, 288, 320] --\n", - "│ │ └─Conv2d: 3-11 [1, 64, 144, 160] 32,832\n", - "│ │ └─Normalize: 3-12 [1, 64, 144, 160] 128\n", - "│ │ └─Mish: 3-13 [1, 64, 144, 160] --\n", - "│ │ └─Mish: 3-14 [1, 64, 144, 160] --\n", - "│ │ └─Mish: 3-15 [1, 64, 144, 160] --\n", - "│ │ └─Conv2d: 3-16 [1, 128, 72, 80] 131,200\n", - "│ │ └─Residual: 3-17 [1, 128, 72, 80] 41,280\n", - "│ │ └─Residual: 3-18 [1, 128, 72, 80] 41,280\n", - "│ │ └─Residual: 3-19 [1, 128, 72, 80] 41,280\n", - "│ │ └─Residual: 3-20 [1, 128, 72, 80] 41,280\n", - "├─Conv2d: 1-2 [1, 32, 72, 80] 4,128\n", - "├─VectorQuantizer: 1-3 [1, 32, 72, 80] --\n", - "├─Conv2d: 1-4 [1, 128, 72, 80] 4,224\n", - "├─Decoder: 1-5 [1, 1, 576, 640] --\n", - "│ └─Sequential: 2-2 [1, 1, 576, 640] --\n", - "│ │ └─Residual: 3-21 [1, 128, 72, 80] 41,280\n", - "│ │ └─Residual: 3-22 [1, 128, 72, 80] 41,280\n", - "│ │ └─Residual: 3-23 [1, 128, 72, 80] 41,280\n", - "│ │ └─Residual: 3-24 [1, 128, 72, 80] 41,280\n", - "│ │ └─Normalize: 3-25 [1, 128, 72, 80] 256\n", - "│ │ └─Mish: 3-26 [1, 128, 72, 80] --\n", - "│ │ └─Mish: 3-27 [1, 128, 72, 80] --\n", - "│ │ └─Mish: 3-28 [1, 128, 72, 80] --\n", - "│ │ └─ConvTranspose2d: 3-29 [1, 64, 144, 160] 131,136\n", - "│ │ └─Normalize: 3-30 [1, 64, 144, 160] 128\n", - "│ │ └─Mish: 3-31 [1, 64, 144, 160] --\n", - "│ │ └─Mish: 3-32 [1, 64, 144, 160] --\n", - "│ │ └─Mish: 3-33 [1, 64, 144, 160] --\n", - "│ │ └─ConvTranspose2d: 3-34 [1, 32, 288, 320] 32,800\n", - "│ │ └─Normalize: 3-35 [1, 32, 288, 320] 64\n", - "│ │ └─Mish: 3-36 [1, 32, 288, 320] --\n", - "│ │ └─Mish: 3-37 [1, 32, 288, 320] --\n", - "│ │ └─Mish: 3-38 [1, 32, 288, 320] --\n", - "│ │ └─ConvTranspose2d: 3-39 [1, 32, 576, 640] 16,416\n", - "│ │ └─Normalize: 3-40 [1, 32, 576, 640] 64\n", - "│ │ └─Conv2d: 3-41 [1, 1, 576, 640] 289\n", - "====================================================================================================\n", - "Total params: 700,769\n", - "Trainable params: 700,769\n", - "Non-trainable params: 0\n", - "Total mult-adds (G): 17.28\n", - "====================================================================================================\n", - "Input size (MB): 1.47\n", - "Forward/backward pass size (MB): 659.13\n", - "Params size (MB): 2.80\n", - "Estimated Total Size (MB): 663.41\n", - "====================================================================================================" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "summary(net, (1, 1, 576, 640), device=\"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f880b03-d641-4640-acd3-aa5666ca5184", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/05a-UNet.ipynb b/notebooks/05a-UNet.ipynb deleted file mode 100644 index 3070e2d..0000000 --- a/notebooks/05a-UNet.ipynb +++ /dev/null @@ -1,482 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from PIL import Image\n", - "import torch\n", - "from torch import nn\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.unet import UNet" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "net = UNet()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.rand(1, 1, 256, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ModuleList(\n", - " (0): _DilationBlock(\n", - " (activation): ELU(alpha=1.0, inplace=True)\n", - " (conv): Sequential(\n", - " (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " )\n", - " (1): _DilationBlock(\n", - " (activation): ELU(alpha=1.0, inplace=True)\n", - " (conv): Sequential(\n", - " (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (conv1): Sequential(\n", - " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " )\n", - " (2): _DilationBlock(\n", - " (activation): ELU(alpha=1.0, inplace=True)\n", - " (conv): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (conv1): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " )\n", - " (3): _DilationBlock(\n", - " (activation): ELU(alpha=1.0, inplace=True)\n", - " (conv): Sequential(\n", - " (0): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (conv1): Sequential(\n", - " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): ELU(alpha=1.0, inplace=True)\n", - " )\n", - " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.encoder_blocks" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ModuleList(\n", - " (0): _UpSamplingBlock(\n", - " (conv_block): _ConvBlock(\n", - " (activation): ReLU(inplace=True)\n", - " (block): Sequential(\n", - " (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " )\n", - " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n", - " )\n", - " (1): _UpSamplingBlock(\n", - " (conv_block): _ConvBlock(\n", - " (activation): ReLU(inplace=True)\n", - " (block): Sequential(\n", - " (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " )\n", - " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n", - " )\n", - " (2): _UpSamplingBlock(\n", - " (conv_block): _ConvBlock(\n", - " (activation): ReLU(inplace=True)\n", - " (block): Sequential(\n", - " (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " )\n", - " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n", - " )\n", - ")" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.decoder_blocks" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.head" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "yy = net(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "y = (torch.randn(1, 256, 256) > 0).long()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 3, 256, 256])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "yy.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[1, 0, 1, ..., 0, 1, 0],\n", - " [1, 0, 1, ..., 0, 1, 0],\n", - " [1, 1, 0, ..., 1, 1, 0],\n", - " ...,\n", - " [1, 0, 0, ..., 0, 1, 1],\n", - " [0, 0, 1, ..., 1, 1, 0],\n", - " [0, 0, 1, ..., 0, 0, 0]]])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "loss = nn.CrossEntropyLoss()" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(1.2502, grad_fn=)" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "loss(yy, y)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[-0.1692, 0.1223, 0.1750, ..., -0.1869, -0.0585, 0.0462],\n", - " [-0.1302, -0.0230, 0.3185, ..., -0.3760, 0.0204, -0.0686],\n", - " [-0.1062, -0.0216, 0.4592, ..., 0.0990, 0.0808, -0.1419],\n", - " ...,\n", - " [ 0.1386, -0.2856, 0.3074, ..., -0.3874, -0.0322, 0.0503],\n", - " [ 0.3562, -0.0960, 0.0815, ..., 0.1893, 0.1438, 0.2804],\n", - " [-0.2106, -0.1988, 0.0016, ..., -0.0031, -0.2820, 0.0113]],\n", - "\n", - " [[-0.1542, -0.1322, -0.3917, ..., -0.2297, -0.2328, 0.0103],\n", - " [ 0.1040, 0.2189, -0.3661, ..., 0.4818, -0.3737, 0.1117],\n", - " [ 0.0735, -0.6487, -0.1899, ..., 0.2213, -0.1529, -0.1020],\n", - " ...,\n", - " [-0.2046, -0.1477, 0.2941, ..., 0.0652, -0.7276, 0.1676],\n", - " [ 0.0413, -0.2013, -0.3192, ..., -0.4947, -0.1179, -0.1000],\n", - " [-0.4108, 0.0199, 0.2238, ..., -0.4482, -0.2370, 0.0119]],\n", - "\n", - " [[ 0.0834, 0.1303, 0.0629, ..., 0.4766, -0.0481, 0.2538],\n", - " [ 0.1218, 0.1324, 0.2464, ..., 0.0081, 0.4444, 0.4583],\n", - " [ 0.1155, 0.1417, 0.2248, ..., 0.6365, -0.0040, 0.3144],\n", - " ...,\n", - " [ 0.0744, -0.0751, -0.5654, ..., -0.2890, -0.0437, 0.2719],\n", - " [ 0.1057, -0.1093, -0.3803, ..., 0.0229, 0.1403, 0.0944],\n", - " [-0.0958, -0.3931, -0.0186, ..., 0.2102, -0.0842, 0.1909]]]],\n", - " grad_fn=)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "yy" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [], - "source": [ - "from torchsummary import summary" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─ModuleList: 1 [] --\n", - "| └─DownSamplingBlock: 2-1 [-1, 64, 128, 128] --\n", - "| | └─ConvBlock: 3-1 [-1, 64, 256, 256] 37,824\n", - "| | └─MaxPool2d: 3-2 [-1, 64, 128, 128] --\n", - "| └─DownSamplingBlock: 2-2 [-1, 128, 64, 64] --\n", - "| | └─ConvBlock: 3-3 [-1, 128, 128, 128] 221,952\n", - "| | └─MaxPool2d: 3-4 [-1, 128, 64, 64] --\n", - "| └─DownSamplingBlock: 2-3 [-1, 256, 32, 32] --\n", - "| | └─ConvBlock: 3-5 [-1, 256, 64, 64] 886,272\n", - "| | └─MaxPool2d: 3-6 [-1, 256, 32, 32] --\n", - "| └─DownSamplingBlock: 2-4 [-1, 512, 32, 32] --\n", - "| | └─ConvBlock: 3-7 [-1, 512, 32, 32] 3,542,016\n", - "├─ModuleList: 1 [] --\n", - "| └─UpSamplingBlock: 2-5 [-1, 256, 64, 64] --\n", - "| | └─Upsample: 3-8 [-1, 512, 64, 64] --\n", - "| | └─ConvBlock: 3-9 [-1, 256, 64, 64] 2,360,832\n", - "| └─UpSamplingBlock: 2-6 [-1, 128, 128, 128] --\n", - "| | └─Upsample: 3-10 [-1, 256, 128, 128] --\n", - "| | └─ConvBlock: 3-11 [-1, 128, 128, 128] 590,592\n", - "| └─UpSamplingBlock: 2-7 [-1, 64, 256, 256] --\n", - "| | └─Upsample: 3-12 [-1, 128, 256, 256] --\n", - "| | └─ConvBlock: 3-13 [-1, 64, 256, 256] 147,840\n", - "├─Conv2d: 1-1 [-1, 3, 256, 256] 195\n", - "==========================================================================================\n", - "Total params: 7,787,523\n", - "Trainable params: 7,787,523\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 35.93\n", - "==========================================================================================\n", - "Input size (MB): 0.25\n", - "Forward/backward pass size (MB): 1.50\n", - "Params size (MB): 29.71\n", - "Estimated Total Size (MB): 31.46\n", - "==========================================================================================\n" - ] - }, - { - "data": { - "text/plain": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─ModuleList: 1 [] --\n", - "| └─DownSamplingBlock: 2-1 [-1, 64, 128, 128] --\n", - "| | └─ConvBlock: 3-1 [-1, 64, 256, 256] 37,824\n", - "| | └─MaxPool2d: 3-2 [-1, 64, 128, 128] --\n", - "| └─DownSamplingBlock: 2-2 [-1, 128, 64, 64] --\n", - "| | └─ConvBlock: 3-3 [-1, 128, 128, 128] 221,952\n", - "| | └─MaxPool2d: 3-4 [-1, 128, 64, 64] --\n", - "| └─DownSamplingBlock: 2-3 [-1, 256, 32, 32] --\n", - "| | └─ConvBlock: 3-5 [-1, 256, 64, 64] 886,272\n", - "| | └─MaxPool2d: 3-6 [-1, 256, 32, 32] --\n", - "| └─DownSamplingBlock: 2-4 [-1, 512, 32, 32] --\n", - "| | └─ConvBlock: 3-7 [-1, 512, 32, 32] 3,542,016\n", - "├─ModuleList: 1 [] --\n", - "| └─UpSamplingBlock: 2-5 [-1, 256, 64, 64] --\n", - "| | └─Upsample: 3-8 [-1, 512, 64, 64] --\n", - "| | └─ConvBlock: 3-9 [-1, 256, 64, 64] 2,360,832\n", - "| └─UpSamplingBlock: 2-6 [-1, 128, 128, 128] --\n", - "| | └─Upsample: 3-10 [-1, 256, 128, 128] --\n", - "| | └─ConvBlock: 3-11 [-1, 128, 128, 128] 590,592\n", - "| └─UpSamplingBlock: 2-7 [-1, 64, 256, 256] --\n", - "| | └─Upsample: 3-12 [-1, 128, 256, 256] --\n", - "| | └─ConvBlock: 3-13 [-1, 64, 256, 256] 147,840\n", - "├─Conv2d: 1-1 [-1, 3, 256, 256] 195\n", - "==========================================================================================\n", - "Total params: 7,787,523\n", - "Trainable params: 7,787,523\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 35.93\n", - "==========================================================================================\n", - "Input size (MB): 0.25\n", - "Forward/backward pass size (MB): 1.50\n", - "Params size (MB): 29.71\n", - "Estimated Total Size (MB): 31.46\n", - "==========================================================================================" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "summary(net, (1, 256, 256), device=\"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} -- cgit v1.2.3-70-g09d2