summaryrefslogtreecommitdiff
path: root/src/notebooks/05a-UNet.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'src/notebooks/05a-UNet.ipynb')
-rw-r--r--src/notebooks/05a-UNet.ipynb413
1 files changed, 280 insertions, 133 deletions
diff --git a/src/notebooks/05a-UNet.ipynb b/src/notebooks/05a-UNet.ipynb
index c25865a..77d895d 100644
--- a/src/notebooks/05a-UNet.ipynb
+++ b/src/notebooks/05a-UNet.ipynb
@@ -23,76 +23,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "x = 64\n",
- "depth = 4\n",
- "channels = [x * 2 ** i for i in range(4)]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "channels.reverse()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[512, 256, 128, 64]"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "channels"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "m = nn.ModuleList([nn.Conv2d(1,3,2), nn.Linear(1, 5)])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "ename": "ModuleAttributeError",
- "evalue": "'ModuleList' object has no attribute 'reverse'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleAttributeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-12-56d7987510bf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/text-recognizer-cxOiES-R-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 771\u001b[0;31m raise ModuleAttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 772\u001b[0m type(self).__name__, name))\n\u001b[1;32m 773\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mModuleAttributeError\u001b[0m: 'ModuleList' object has no attribute 'reverse'"
- ]
- }
- ],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 40,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -101,7 +32,7 @@
},
{
"cell_type": "code",
- "execution_count": 99,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -110,7 +41,7 @@
},
{
"cell_type": "code",
- "execution_count": 100,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -119,72 +50,68 @@
},
{
"cell_type": "code",
- "execution_count": 101,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ModuleList(\n",
- " (0): DownSamplingBlock(\n",
- " (conv_block): ConvBlock(\n",
- " (activation): ReLU(inplace=True)\n",
- " (block): Sequential(\n",
- " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (2): ReLU(inplace=True)\n",
- " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): ReLU(inplace=True)\n",
- " )\n",
+ " (0): _DilationBlock(\n",
+ " (activation): ELU(alpha=1.0, inplace=True)\n",
+ " (conv): Sequential(\n",
+ " (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
+ " (conv1): Sequential(\n",
+ " (0): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
+ " )\n",
+ " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
- " (1): DownSamplingBlock(\n",
- " (conv_block): ConvBlock(\n",
- " (activation): ReLU(inplace=True)\n",
- " (block): Sequential(\n",
- " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (2): ReLU(inplace=True)\n",
- " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): ReLU(inplace=True)\n",
- " )\n",
+ " (1): _DilationBlock(\n",
+ " (activation): ELU(alpha=1.0, inplace=True)\n",
+ " (conv): Sequential(\n",
+ " (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
+ " )\n",
+ " (conv1): Sequential(\n",
+ " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
- " (2): DownSamplingBlock(\n",
- " (conv_block): ConvBlock(\n",
- " (activation): ReLU(inplace=True)\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (2): ReLU(inplace=True)\n",
- " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): ReLU(inplace=True)\n",
- " )\n",
+ " (2): _DilationBlock(\n",
+ " (activation): ELU(alpha=1.0, inplace=True)\n",
+ " (conv): Sequential(\n",
+ " (0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
+ " )\n",
+ " (conv1): Sequential(\n",
+ " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (down_sampling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
- " (3): DownSamplingBlock(\n",
- " (conv_block): ConvBlock(\n",
- " (activation): ReLU(inplace=True)\n",
- " (block): Sequential(\n",
- " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (2): ReLU(inplace=True)\n",
- " (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
- " (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): ReLU(inplace=True)\n",
- " )\n",
+ " (3): _DilationBlock(\n",
+ " (activation): ELU(alpha=1.0, inplace=True)\n",
+ " (conv): Sequential(\n",
+ " (0): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(6, 6), dilation=(3, 3))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
+ " )\n",
+ " (conv1): Sequential(\n",
+ " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
")"
]
},
- "execution_count": 101,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -195,15 +122,15 @@
},
{
"cell_type": "code",
- "execution_count": 102,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ModuleList(\n",
- " (0): UpSamplingBlock(\n",
- " (conv_block): ConvBlock(\n",
+ " (0): _UpSamplingBlock(\n",
+ " (conv_block): _ConvBlock(\n",
" (activation): ReLU(inplace=True)\n",
" (block): Sequential(\n",
" (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
@@ -216,8 +143,8 @@
" )\n",
" (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
" )\n",
- " (1): UpSamplingBlock(\n",
- " (conv_block): ConvBlock(\n",
+ " (1): _UpSamplingBlock(\n",
+ " (conv_block): _ConvBlock(\n",
" (activation): ReLU(inplace=True)\n",
" (block): Sequential(\n",
" (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
@@ -230,8 +157,8 @@
" )\n",
" (up_sampling): Upsample(scale_factor=2.0, mode=bilinear)\n",
" )\n",
- " (2): UpSamplingBlock(\n",
- " (conv_block): ConvBlock(\n",
+ " (2): _UpSamplingBlock(\n",
+ " (conv_block): _ConvBlock(\n",
" (activation): ReLU(inplace=True)\n",
" (block): Sequential(\n",
" (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
@@ -247,7 +174,7 @@
")"
]
},
- "execution_count": 102,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -258,7 +185,7 @@
},
{
"cell_type": "code",
- "execution_count": 104,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -267,7 +194,7 @@
"Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))"
]
},
- "execution_count": 104,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -278,7 +205,25 @@
},
{
"cell_type": "code",
- "execution_count": 103,
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "yy = net(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y = (torch.randn(1, 256, 256) > 0).long()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -287,21 +232,223 @@
"torch.Size([1, 3, 256, 256])"
]
},
- "execution_count": 103,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "net(x).shape"
+ "yy.shape"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
"metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[1, 0, 1, ..., 0, 1, 0],\n",
+ " [1, 0, 1, ..., 0, 1, 0],\n",
+ " [1, 1, 0, ..., 1, 1, 0],\n",
+ " ...,\n",
+ " [1, 0, 0, ..., 0, 1, 1],\n",
+ " [0, 0, 1, ..., 1, 1, 0],\n",
+ " [0, 0, 1, ..., 0, 0, 0]]])"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [],
- "source": []
+ "source": [
+ "loss = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.2502, grad_fn=<NllLoss2DBackward>)"
+ ]
+ },
+ "execution_count": 55,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "loss(yy, y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[[-0.1692, 0.1223, 0.1750, ..., -0.1869, -0.0585, 0.0462],\n",
+ " [-0.1302, -0.0230, 0.3185, ..., -0.3760, 0.0204, -0.0686],\n",
+ " [-0.1062, -0.0216, 0.4592, ..., 0.0990, 0.0808, -0.1419],\n",
+ " ...,\n",
+ " [ 0.1386, -0.2856, 0.3074, ..., -0.3874, -0.0322, 0.0503],\n",
+ " [ 0.3562, -0.0960, 0.0815, ..., 0.1893, 0.1438, 0.2804],\n",
+ " [-0.2106, -0.1988, 0.0016, ..., -0.0031, -0.2820, 0.0113]],\n",
+ "\n",
+ " [[-0.1542, -0.1322, -0.3917, ..., -0.2297, -0.2328, 0.0103],\n",
+ " [ 0.1040, 0.2189, -0.3661, ..., 0.4818, -0.3737, 0.1117],\n",
+ " [ 0.0735, -0.6487, -0.1899, ..., 0.2213, -0.1529, -0.1020],\n",
+ " ...,\n",
+ " [-0.2046, -0.1477, 0.2941, ..., 0.0652, -0.7276, 0.1676],\n",
+ " [ 0.0413, -0.2013, -0.3192, ..., -0.4947, -0.1179, -0.1000],\n",
+ " [-0.4108, 0.0199, 0.2238, ..., -0.4482, -0.2370, 0.0119]],\n",
+ "\n",
+ " [[ 0.0834, 0.1303, 0.0629, ..., 0.4766, -0.0481, 0.2538],\n",
+ " [ 0.1218, 0.1324, 0.2464, ..., 0.0081, 0.4444, 0.4583],\n",
+ " [ 0.1155, 0.1417, 0.2248, ..., 0.6365, -0.0040, 0.3144],\n",
+ " ...,\n",
+ " [ 0.0744, -0.0751, -0.5654, ..., -0.2890, -0.0437, 0.2719],\n",
+ " [ 0.1057, -0.1093, -0.3803, ..., 0.0229, 0.1403, 0.0944],\n",
+ " [-0.0958, -0.3931, -0.0186, ..., 0.2102, -0.0842, 0.1909]]]],\n",
+ " grad_fn=<MkldnnConvolutionBackward>)"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "yy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==========================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "==========================================================================================\n",
+ "├─ModuleList: 1 [] --\n",
+ "| └─DownSamplingBlock: 2-1 [-1, 64, 128, 128] --\n",
+ "| | └─ConvBlock: 3-1 [-1, 64, 256, 256] 37,824\n",
+ "| | └─MaxPool2d: 3-2 [-1, 64, 128, 128] --\n",
+ "| └─DownSamplingBlock: 2-2 [-1, 128, 64, 64] --\n",
+ "| | └─ConvBlock: 3-3 [-1, 128, 128, 128] 221,952\n",
+ "| | └─MaxPool2d: 3-4 [-1, 128, 64, 64] --\n",
+ "| └─DownSamplingBlock: 2-3 [-1, 256, 32, 32] --\n",
+ "| | └─ConvBlock: 3-5 [-1, 256, 64, 64] 886,272\n",
+ "| | └─MaxPool2d: 3-6 [-1, 256, 32, 32] --\n",
+ "| └─DownSamplingBlock: 2-4 [-1, 512, 32, 32] --\n",
+ "| | └─ConvBlock: 3-7 [-1, 512, 32, 32] 3,542,016\n",
+ "├─ModuleList: 1 [] --\n",
+ "| └─UpSamplingBlock: 2-5 [-1, 256, 64, 64] --\n",
+ "| | └─Upsample: 3-8 [-1, 512, 64, 64] --\n",
+ "| | └─ConvBlock: 3-9 [-1, 256, 64, 64] 2,360,832\n",
+ "| └─UpSamplingBlock: 2-6 [-1, 128, 128, 128] --\n",
+ "| | └─Upsample: 3-10 [-1, 256, 128, 128] --\n",
+ "| | └─ConvBlock: 3-11 [-1, 128, 128, 128] 590,592\n",
+ "| └─UpSamplingBlock: 2-7 [-1, 64, 256, 256] --\n",
+ "| | └─Upsample: 3-12 [-1, 128, 256, 256] --\n",
+ "| | └─ConvBlock: 3-13 [-1, 64, 256, 256] 147,840\n",
+ "├─Conv2d: 1-1 [-1, 3, 256, 256] 195\n",
+ "==========================================================================================\n",
+ "Total params: 7,787,523\n",
+ "Trainable params: 7,787,523\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (M): 35.93\n",
+ "==========================================================================================\n",
+ "Input size (MB): 0.25\n",
+ "Forward/backward pass size (MB): 1.50\n",
+ "Params size (MB): 29.71\n",
+ "Estimated Total Size (MB): 31.46\n",
+ "==========================================================================================\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "==========================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "==========================================================================================\n",
+ "├─ModuleList: 1 [] --\n",
+ "| └─DownSamplingBlock: 2-1 [-1, 64, 128, 128] --\n",
+ "| | └─ConvBlock: 3-1 [-1, 64, 256, 256] 37,824\n",
+ "| | └─MaxPool2d: 3-2 [-1, 64, 128, 128] --\n",
+ "| └─DownSamplingBlock: 2-2 [-1, 128, 64, 64] --\n",
+ "| | └─ConvBlock: 3-3 [-1, 128, 128, 128] 221,952\n",
+ "| | └─MaxPool2d: 3-4 [-1, 128, 64, 64] --\n",
+ "| └─DownSamplingBlock: 2-3 [-1, 256, 32, 32] --\n",
+ "| | └─ConvBlock: 3-5 [-1, 256, 64, 64] 886,272\n",
+ "| | └─MaxPool2d: 3-6 [-1, 256, 32, 32] --\n",
+ "| └─DownSamplingBlock: 2-4 [-1, 512, 32, 32] --\n",
+ "| | └─ConvBlock: 3-7 [-1, 512, 32, 32] 3,542,016\n",
+ "├─ModuleList: 1 [] --\n",
+ "| └─UpSamplingBlock: 2-5 [-1, 256, 64, 64] --\n",
+ "| | └─Upsample: 3-8 [-1, 512, 64, 64] --\n",
+ "| | └─ConvBlock: 3-9 [-1, 256, 64, 64] 2,360,832\n",
+ "| └─UpSamplingBlock: 2-6 [-1, 128, 128, 128] --\n",
+ "| | └─Upsample: 3-10 [-1, 256, 128, 128] --\n",
+ "| | └─ConvBlock: 3-11 [-1, 128, 128, 128] 590,592\n",
+ "| └─UpSamplingBlock: 2-7 [-1, 64, 256, 256] --\n",
+ "| | └─Upsample: 3-12 [-1, 128, 256, 256] --\n",
+ "| | └─ConvBlock: 3-13 [-1, 64, 256, 256] 147,840\n",
+ "├─Conv2d: 1-1 [-1, 3, 256, 256] 195\n",
+ "==========================================================================================\n",
+ "Total params: 7,787,523\n",
+ "Trainable params: 7,787,523\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (M): 35.93\n",
+ "==========================================================================================\n",
+ "Input size (MB): 0.25\n",
+ "Forward/backward pass size (MB): 1.50\n",
+ "Params size (MB): 29.71\n",
+ "Estimated Total Size (MB): 31.46\n",
+ "=========================================================================================="
+ ]
+ },
+ "execution_count": 47,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "summary(net, (1, 256, 256), device=\"cpu\")"
+ ]
},
{
"cell_type": "code",