summaryrefslogtreecommitdiff
path: root/src/notebooks/05a-UNet.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'src/notebooks/05a-UNet.ipynb')
-rw-r--r--src/notebooks/05a-UNet.ipynb335
1 files changed, 335 insertions, 0 deletions
diff --git a/src/notebooks/05a-UNet.ipynb b/src/notebooks/05a-UNet.ipynb
new file mode 100644
index 0000000..c25865a
--- /dev/null
+++ b/src/notebooks/05a-UNet.ipynb
@@ -0,0 +1,335 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = 64\n",
+ "depth = 4\n",
+ "channels = [x * 2 ** i for i in range(4)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "channels.reverse()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[512, 256, 128, 64]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "channels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "m = nn.ModuleList([nn.Conv2d(1,3,2), nn.Linear(1, 5)])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ModuleAttributeError",
+ "evalue": "'ModuleList' object has no attribute 'reverse'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleAttributeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-12-56d7987510bf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/text-recognizer-cxOiES-R-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 771\u001b[0;31m raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 772\u001b[0m type(self).__name__, name))\n\u001b[1;32m 773\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mModuleAttributeError\u001b[0m: 'ModuleList' object has no attribute 'reverse'"
+ ]
+ }
+ ],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.unet import UNet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = UNet()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 100,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = torch.rand(1, 1, 256, 256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 101,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ModuleList(\n",
+ " (0): DownSamplingBlock(\n",
+ " (conv_block): ConvBlock(\n",
+ " (activation): ReLU(inplace=True)\n",
+ " (block): Sequential(\n",
+ " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): ReLU(inplace=True)\n",
+ " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (5): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (1): DownSamplingBlock(\n",
+ " (conv_block): ConvBlock(\n",
+ " (activation): ReLU(inplace=True)\n",
+ " (block): Sequential(\n",
+ " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): ReLU(inplace=True)\n",
+ " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (5): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (2): DownSamplingBlock(\n",
+ " (conv_block): ConvBlock(\n",
+ " (activation): ReLU(inplace=True)\n",
+ " (block): Sequential(\n",
+ " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): ReLU(inplace=True)\n",
+ " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (5): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (3): DownSamplingBlock(\n",
+ " (conv_block): ConvBlock(\n",
+ " (activation): ReLU(inplace=True)\n",
+ " (block): Sequential(\n",
+ " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): ReLU(inplace=True)\n",
+ " (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (5): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 101,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "net.encoder_blocks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 102,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ModuleList(\n",
+ " (0): UpSamplingBlock(\n",
+ " (conv_block): ConvBlock(\n",
+ " (activation): ReLU(inplace=True)\n",
+ " (block): Sequential(\n",
+ " (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): ReLU(inplace=True)\n",
+ " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (5): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
+ " )\n",
+ " (1): UpSamplingBlock(\n",
+ " (conv_block): ConvBlock(\n",
+ " (activation): ReLU(inplace=True)\n",
+ " (block): Sequential(\n",
+ " (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): ReLU(inplace=True)\n",
+ " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (5): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
+ " )\n",
+ " (2): UpSamplingBlock(\n",
+ " (conv_block): ConvBlock(\n",
+ " (activation): ReLU(inplace=True)\n",
+ " (block): Sequential(\n",
+ " (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): ReLU(inplace=True)\n",
+ " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (5): ReLU(inplace=True)\n",
+ " )\n",
+ " )\n",
+ " (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 102,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "net.decoder_blocks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))"
+ ]
+ },
+ "execution_count": 104,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "net.head"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 103,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 3, 256, 256])"
+ ]
+ },
+ "execution_count": 103,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "net(x).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}