summaryrefslogtreecommitdiff
path: root/src/notebooks
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-20 22:18:35 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-20 22:18:35 +0200
commit1f459ba19422593de325983040e176f97cf4ffc0 (patch)
tree89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/notebooks
parent95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff)
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/notebooks')
-rw-r--r--src/notebooks/00-testing-stuff-out.ipynb639
-rw-r--r--src/notebooks/01-look-at-emnist.ipynb25
2 files changed, 639 insertions, 25 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index 49ca4c4..3f008c3 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -2,11 +2,120 @@
"cells": [
{
"cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.nn.modules.activation.SELU"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.nn.SELU"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
- "import torch"
+ "a = \"nNone\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = a or \"relu\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'nnone'"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b.lower()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'nNone'"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b"
]
},
{
@@ -986,28 +1095,16 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 51,
"metadata": {},
- "outputs": [
- {
- "ename": "ModuleNotFoundError",
- "evalue": "No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-20-68e3c8bf3e1f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtqdm\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtqdm_auto\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "import tqdm.auto.tqdm as tqdm_auto"
+ "import tqdm"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 52,
"metadata": {},
"outputs": [
{
@@ -1016,7 +1113,7 @@
"tqdm.notebook.tqdm_notebook"
]
},
- "execution_count": 19,
+ "execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
@@ -1027,25 +1124,50 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tqdm.auto.tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"def test():\n",
- " for i in range(9):\n",
+ " for i in tqdm.auto.tqdm(range(9)):\n",
" pass\n",
- " print(i)"
+ " print(i)\n",
+ " "
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 55,
"metadata": {},
"outputs": [
{
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e1d3b25d4ee141e882e316ec54e79d60",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
"name": "stdout",
"output_type": "stream",
"text": [
+ "\n",
"8\n"
]
}
@@ -1056,6 +1178,479 @@
},
{
"cell_type": "code",
+ "execution_count": 58,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from time import sleep"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "41b743273ce14236bcb65782dbcd2e75",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "pbar = tqdm.auto.tqdm([\"a\", \"b\", \"c\", \"d\"], leave=True)\n",
+ "for char in pbar:\n",
+ " pbar.set_description(\"Processing %s\" % char)\n",
+ "# pbar.set_prefix()\n",
+ " sleep(0.25)\n",
+ "pbar.set_postfix({\"hej\": 0.32})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pbar.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cb5ad8d6109f4b1495b8fc7422bafd01",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "with tqdm.auto.tqdm(total=10, bar_format=\"{postfix[0]} {postfix[1][value]:>8.2g}\",\n",
+ " postfix=[\"Batch\", dict(value=0)]) as t:\n",
+ " for i in range(10):\n",
+ " sleep(0.1)\n",
+ "# t.postfix[2][\"value\"] = 3 \n",
+ " t.postfix[1][\"value\"] = i / 2\n",
+ " t.update()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0b341d49ad074823881e84a538bcad0c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "with tqdm.auto.tqdm(total=100, leave=True) as pbar:\n",
+ " for i in range(2):\n",
+ " for i in range(10):\n",
+ " sleep(0.1)\n",
+ " pbar.update(10)\n",
+ " pbar.set_postfix({\"adaf\": 23})\n",
+ " pbar.set_postfix({\"hej\": 0.32})\n",
+ " pbar.reset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "IdentityBlock(\n",
+ " (blocks): Identity()\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Identity()\n",
+ ")"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "IdentityBlock(32, 64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ResidualBlock(\n",
+ " (blocks): Identity()\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ResidualBlock(32, 64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BasicBlock(\n",
+ " (blocks): Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 32, 224, 224))\n",
+ "\n",
+ "block = BasicBlock(32, 64)\n",
+ "block(dummy).shape\n",
+ "print(block)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BottleNeckBlock(\n",
+ " (blocks): Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (3): ReLU(inplace=True)\n",
+ " (4): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 32, 10, 10))\n",
+ "\n",
+ "block = BottleNeckBlock(32, 64)\n",
+ "block(dummy).shape\n",
+ "print(block)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 128, 24, 24])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 64, 48, 48))\n",
+ "\n",
+ "layer = ResidualLayer(64, 128, block=BasicBlock, num_blocks=3)\n",
+ "layer(dummy).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(64, 128), (128, 256), (256, 512)]"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "blocks_sizes=[64, 128, 256, 512]\n",
+ "list(zip(blocks_sizes, blocks_sizes[1:]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e = Encoder(depths=[1, 1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 32, 15, 15] 800\n",
+ " BatchNorm2d-2 [-1, 32, 15, 15] 64\n",
+ " ReLU-3 [-1, 32, 15, 15] 0\n",
+ " MaxPool2d-4 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-6 [-1, 32, 8, 8] 64\n",
+ " ReLU-7 [-1, 32, 8, 8] 0\n",
+ " ReLU-8 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-10 [-1, 32, 8, 8] 64\n",
+ " ReLU-11 [-1, 32, 8, 8] 0\n",
+ " ReLU-12 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-13 [-1, 32, 8, 8] 0\n",
+ " ResidualLayer-14 [-1, 32, 8, 8] 0\n",
+ " Conv2d-15 [-1, 64, 4, 4] 2,048\n",
+ " BatchNorm2d-16 [-1, 64, 4, 4] 128\n",
+ " Conv2dAuto-17 [-1, 64, 4, 4] 18,432\n",
+ " BatchNorm2d-18 [-1, 64, 4, 4] 128\n",
+ " ReLU-19 [-1, 64, 4, 4] 0\n",
+ " ReLU-20 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-21 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-22 [-1, 64, 4, 4] 128\n",
+ " ReLU-23 [-1, 64, 4, 4] 0\n",
+ " ReLU-24 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-25 [-1, 64, 4, 4] 0\n",
+ " ResidualLayer-26 [-1, 64, 4, 4] 0\n",
+ "================================================================\n",
+ "Total params: 77,152\n",
+ "Trainable params: 77,152\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.00\n",
+ "Forward/backward pass size (MB): 0.43\n",
+ "Params size (MB): 0.29\n",
+ "Estimated Total Size (MB): 0.73\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(e, (1, 28, 28), device=\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "resnet = ResidualNetwork(1, 80, activation=\"selu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 32, 15, 15] 800\n",
+ " BatchNorm2d-2 [-1, 32, 15, 15] 64\n",
+ " SELU-3 [-1, 32, 15, 15] 0\n",
+ " MaxPool2d-4 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-6 [-1, 32, 8, 8] 64\n",
+ " SELU-7 [-1, 32, 8, 8] 0\n",
+ " SELU-8 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-10 [-1, 32, 8, 8] 64\n",
+ " SELU-11 [-1, 32, 8, 8] 0\n",
+ " SELU-12 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-13 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-14 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-15 [-1, 32, 8, 8] 64\n",
+ " SELU-16 [-1, 32, 8, 8] 0\n",
+ " SELU-17 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-18 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-19 [-1, 32, 8, 8] 64\n",
+ " SELU-20 [-1, 32, 8, 8] 0\n",
+ " SELU-21 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-22 [-1, 32, 8, 8] 0\n",
+ " ResidualLayer-23 [-1, 32, 8, 8] 0\n",
+ " Conv2d-24 [-1, 64, 4, 4] 2,048\n",
+ " BatchNorm2d-25 [-1, 64, 4, 4] 128\n",
+ " Conv2dAuto-26 [-1, 64, 4, 4] 18,432\n",
+ " BatchNorm2d-27 [-1, 64, 4, 4] 128\n",
+ " SELU-28 [-1, 64, 4, 4] 0\n",
+ " SELU-29 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-30 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-31 [-1, 64, 4, 4] 128\n",
+ " SELU-32 [-1, 64, 4, 4] 0\n",
+ " SELU-33 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-34 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-35 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-36 [-1, 64, 4, 4] 128\n",
+ " SELU-37 [-1, 64, 4, 4] 0\n",
+ " SELU-38 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-39 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-40 [-1, 64, 4, 4] 128\n",
+ " SELU-41 [-1, 64, 4, 4] 0\n",
+ " SELU-42 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-43 [-1, 64, 4, 4] 0\n",
+ " ResidualLayer-44 [-1, 64, 4, 4] 0\n",
+ " Encoder-45 [-1, 64, 4, 4] 0\n",
+ " Reduce-46 [-1, 64] 0\n",
+ " Linear-47 [-1, 80] 5,200\n",
+ " Decoder-48 [-1, 80] 0\n",
+ "================================================================\n",
+ "Total params: 174,896\n",
+ "Trainable params: 174,896\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.00\n",
+ "Forward/backward pass size (MB): 0.65\n",
+ "Params size (MB): 0.67\n",
+ "Estimated Total Size (MB): 1.32\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(resnet, (1, 28, 28), device=\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
diff --git a/src/notebooks/01-look-at-emnist.ipynb b/src/notebooks/01-look-at-emnist.ipynb
index a68b418..8648afb 100644
--- a/src/notebooks/01-look-at-emnist.ipynb
+++ b/src/notebooks/01-look-at-emnist.ipynb
@@ -31,12 +31,31 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "dataset = EmnistDataset()\n",
- "dataset.load_emnist_dataset()"
+ "dataset = EmnistDataset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Tensor"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(dataset.data)"
]
},
{