From 73ae250d7993fa48eccff4042ecd6bf768650bf3 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 18 Nov 2020 23:35:35 +0100 Subject: UNet implemented. --- src/notebooks/05a-UNet.ipynb | 335 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 src/notebooks/05a-UNet.ipynb (limited to 'src/notebooks/05a-UNet.ipynb') diff --git a/src/notebooks/05a-UNet.ipynb b/src/notebooks/05a-UNet.ipynb new file mode 100644 index 0000000..c25865a --- /dev/null +++ b/src/notebooks/05a-UNet.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "import torch\n", + "from torch import nn\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "x = 64\n", + "depth = 4\n", + "channels = [x * 2 ** i for i in range(4)]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "channels.reverse()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[512, 256, 128, 64]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "channels" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "m = nn.ModuleList([nn.Conv2d(1,3,2), nn.Linear(1, 5)])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleAttributeError", + "evalue": "'ModuleList' object has no attribute 'reverse'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/text-recognizer-cxOiES-R-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 771\u001b[0;31m raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 772\u001b[0m type(self).__name__, name))\n\u001b[1;32m 773\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleAttributeError\u001b[0m: 'ModuleList' object has no attribute 'reverse'" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.unet import UNet" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [], + "source": [ + "net = UNet()" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.rand(1, 1, 256, 256)" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ModuleList(\n", + " (0): DownSamplingBlock(\n", + " (conv_block): ConvBlock(\n", + " (activation): ReLU(inplace=True)\n", + " (block): Sequential(\n", + " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (1): DownSamplingBlock(\n", + " (conv_block): ConvBlock(\n", + " (activation): ReLU(inplace=True)\n", + " (block): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (2): DownSamplingBlock(\n", + " (conv_block): ConvBlock(\n", + " (activation): ReLU(inplace=True)\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (3): DownSamplingBlock(\n", + " (conv_block): ConvBlock(\n", + " (activation): ReLU(inplace=True)\n", + " (block): Sequential(\n", + " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 101, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.encoder_blocks" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ModuleList(\n", + " (0): UpSamplingBlock(\n", + " (conv_block): ConvBlock(\n", + " (activation): ReLU(inplace=True)\n", + " (block): Sequential(\n", + " (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n", + " )\n", + " (1): UpSamplingBlock(\n", + " (conv_block): ConvBlock(\n", + " (activation): ReLU(inplace=True)\n", + " (block): Sequential(\n", + " (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n", + " )\n", + " (2): UpSamplingBlock(\n", + " (conv_block): ConvBlock(\n", + " (activation): ReLU(inplace=True)\n", + " (block): Sequential(\n", + " (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n", + " )\n", + ")" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.decoder_blocks" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.head" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 3, 256, 256])" + ] + }, + "execution_count": 103, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net(x).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} -- cgit v1.2.3-70-g09d2