From 73ae250d7993fa48eccff4042ecd6bf768650bf3 Mon Sep 17 00:00:00 2001
From: aktersnurra <grydholm@kth.se>
Date: Wed, 18 Nov 2020 23:35:35 +0100
Subject: UNet implemented.

---
 src/notebooks/00-testing-stuff-out.ipynb |   2 +-
 src/notebooks/05a-UNet.ipynb             | 335 +++++++++++++++++++++++++++++++
 2 files changed, 336 insertions(+), 1 deletion(-)
 create mode 100644 src/notebooks/05a-UNet.ipynb

(limited to 'src/notebooks')

diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index 3686dcd..96a0c5a 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -1352,7 +1352,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.4"
+   "version": "3.8.2"
   }
  },
  "nbformat": 4,
diff --git a/src/notebooks/05a-UNet.ipynb b/src/notebooks/05a-UNet.ipynb
new file mode 100644
index 0000000..c25865a
--- /dev/null
+++ b/src/notebooks/05a-UNet.ipynb
@@ -0,0 +1,335 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "%matplotlib inline\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from PIL import Image\n",
+    "import torch\n",
+    "from torch import nn\n",
+    "from importlib.util import find_spec\n",
+    "if find_spec(\"text_recognizer\") is None:\n",
+    "    import sys\n",
+    "    sys.path.append('..')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = 64\n",
+    "depth = 4\n",
+    "channels = [x * 2 ** i for i in range(4)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "channels.reverse()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[512, 256, 128, 64]"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "channels"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "m = nn.ModuleList([nn.Conv2d(1,3,2), nn.Linear(1, 5)])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ModuleAttributeError",
+     "evalue": "'ModuleList' object has no attribute 'reverse'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleAttributeError\u001b[0m                      Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-12-56d7987510bf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/text-recognizer-cxOiES-R-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m    769\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    770\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 771\u001b[0;31m         raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m    772\u001b[0m             type(self).__name__, name))\n\u001b[1;32m    773\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mModuleAttributeError\u001b[0m: 'ModuleList' object has no attribute 'reverse'"
+     ]
+    }
+   ],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from text_recognizer.networks.unet import UNet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 99,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net = UNet()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 100,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.rand(1, 1, 256, 256)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 101,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "ModuleList(\n",
+       "  (0): DownSamplingBlock(\n",
+       "    (conv_block): ConvBlock(\n",
+       "      (activation): ReLU(inplace=True)\n",
+       "      (block): Sequential(\n",
+       "        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (2): ReLU(inplace=True)\n",
+       "        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (5): ReLU(inplace=True)\n",
+       "      )\n",
+       "    )\n",
+       "    (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+       "  )\n",
+       "  (1): DownSamplingBlock(\n",
+       "    (conv_block): ConvBlock(\n",
+       "      (activation): ReLU(inplace=True)\n",
+       "      (block): Sequential(\n",
+       "        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (2): ReLU(inplace=True)\n",
+       "        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (5): ReLU(inplace=True)\n",
+       "      )\n",
+       "    )\n",
+       "    (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+       "  )\n",
+       "  (2): DownSamplingBlock(\n",
+       "    (conv_block): ConvBlock(\n",
+       "      (activation): ReLU(inplace=True)\n",
+       "      (block): Sequential(\n",
+       "        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (2): ReLU(inplace=True)\n",
+       "        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (5): ReLU(inplace=True)\n",
+       "      )\n",
+       "    )\n",
+       "    (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+       "  )\n",
+       "  (3): DownSamplingBlock(\n",
+       "    (conv_block): ConvBlock(\n",
+       "      (activation): ReLU(inplace=True)\n",
+       "      (block): Sequential(\n",
+       "        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (2): ReLU(inplace=True)\n",
+       "        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (5): ReLU(inplace=True)\n",
+       "      )\n",
+       "    )\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 101,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "net.encoder_blocks"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 102,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "ModuleList(\n",
+       "  (0): UpSamplingBlock(\n",
+       "    (conv_block): ConvBlock(\n",
+       "      (activation): ReLU(inplace=True)\n",
+       "      (block): Sequential(\n",
+       "        (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (2): ReLU(inplace=True)\n",
+       "        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (5): ReLU(inplace=True)\n",
+       "      )\n",
+       "    )\n",
+       "    (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
+       "  )\n",
+       "  (1): UpSamplingBlock(\n",
+       "    (conv_block): ConvBlock(\n",
+       "      (activation): ReLU(inplace=True)\n",
+       "      (block): Sequential(\n",
+       "        (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (2): ReLU(inplace=True)\n",
+       "        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (5): ReLU(inplace=True)\n",
+       "      )\n",
+       "    )\n",
+       "    (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
+       "  )\n",
+       "  (2): UpSamplingBlock(\n",
+       "    (conv_block): ConvBlock(\n",
+       "      (activation): ReLU(inplace=True)\n",
+       "      (block): Sequential(\n",
+       "        (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (2): ReLU(inplace=True)\n",
+       "        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+       "        (5): ReLU(inplace=True)\n",
+       "      )\n",
+       "    )\n",
+       "    (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 102,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "net.decoder_blocks"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 104,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))"
+      ]
+     },
+     "execution_count": 104,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "net.head"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 103,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 3, 256, 256])"
+      ]
+     },
+     "execution_count": 103,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "net(x).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
-- 
cgit v1.2.3-70-g09d2