From e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Tue, 8 Sep 2020 23:14:23 +0200 Subject: IAM datasets implemented. --- src/notebooks/00-testing-stuff-out.ipynb | 1660 +++--------------------------- 1 file changed, 133 insertions(+), 1527 deletions(-) (limited to 'src/notebooks/00-testing-stuff-out.ipynb') diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index 3f008c3..ff9fb20 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -22,1104 +22,11 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.cuda.is_available()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.nn.modules.activation.SELU" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.nn.SELU" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "a = \"nNone\"" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "b = a or \"relu\"" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'nnone'" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.lower()" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'nNone'" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.optim.lr_scheduler.StepLR" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "getattr(torch.optim.lr_scheduler, \"StepLR\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a = getattr(torch.nn, \"ReLU\")()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "loss = getattr(torch.nn, \"L1Loss\")()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "input = torch.randn(3, 5, requires_grad=True)\n", - "target = torch.randn(3, 5)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b = torch.randn(2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a(b)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "output = loss(input, target)\n", - "output.backward()" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(1.1283)" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.tensor(output.item())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "s = 1.\n", - "if s is not None:\n", - " assert 0.0 < s < 1.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class A:\n", - " @property\n", - " def __name__(self):\n", - " return \"adafa\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a = A()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.__name__" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from training.gpu_manager import GPUManager" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "gpu_manager = GPUManager(True)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-21 14:10:13.170 | DEBUG | training.gpu_manager:_get_free_gpu:57 - pid 11721 picking gpu 0\n" - ] - }, - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gpu_manager.get_free_gpu()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "p = Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "str(p).split(\"/\")[0] + \"/\" + str(p).split(\"/\")[1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "p.parents[0].resolve()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "p.exists()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d = 'Experiment JSON, e.g. \\'{\"dataset\": \"EmnistDataset\", \"model\": \"CharacterModel\", \"network\": \"mlp\"}\\''" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(d)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import yaml" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "path = \"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/sample_experiment.yml\"" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "with open(path) as f:\n", - " d = yaml.safe_load(f)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "experiment_config = d[\"experiments\"][0]" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'dataloader': 'EmnistDataLoader',\n", - " 'data_loader_args': {'splits': ['train', 'val'],\n", - " 'sample_to_balance': True,\n", - " 'subsample_fraction': None,\n", - " 'transform': None,\n", - " 'target_transform': None,\n", - " 'batch_size': 256,\n", - " 'shuffle': True,\n", - " 'num_workers': 0,\n", - " 'cuda': True,\n", - " 'seed': 4711},\n", - " 'model': 'CharacterModel',\n", - " 'metrics': ['accuracy'],\n", - " 'network': 'MLP',\n", - " 'network_args': {'input_size': 784, 'num_layers': 2},\n", - " 'train_args': {'batch_size': 256, 'epochs': 16},\n", - " 'criterion': 'CrossEntropyLoss',\n", - " 'criterion_args': {'weight': None, 'ignore_index': -100, 'reduction': 'mean'},\n", - " 'optimizer': 'AdamW',\n", - " 'optimizer_args': {'lr': 0.0003,\n", - " 'betas': [0.9, 0.999],\n", - " 'eps': 1e-08,\n", - " 'weight_decay': 0,\n", - " 'amsgrad': False},\n", - " 'lr_scheduler': 'OneCycleLR',\n", - " 'lr_scheduler_args': {'max_lr': 3e-05, 'epochs': 16}}" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "experiment_config" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "import importlib" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "network_module = importlib.import_module(\"text_recognizer.networks\")\n", - "network_fn_ = getattr(network_module, experiment_config[\"network\"])\n", - "network_args = experiment_config.get(\"network_args\", {})" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(1, 784)" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "(1,) + (network_args[\"input_size\"],)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer_ = getattr(torch.optim, experiment_config[\"optimizer\"])\n", - "optimizer_args = experiment_config.get(\"optimizer_args\", {})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer_" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer_args" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "network_args" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "network_fn_" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "net = network_fn_(**network_args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer_(net.parameters() , **optimizer_args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "criterion_ = getattr(torch.nn, experiment_config[\"criterion\"])\n", - "criterion_args = experiment_config.get(\"criterion_args\", {})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "criterion_(**criterion_args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "models_module = importlib.import_module(\"text_recognizer.models\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metrics = {metric: getattr(models_module, metric) for metric in experiment_config[\"metrics\"]}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "torch.randn(3, 10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "torch.randn(3, 1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metrics['accuracy'](torch.randn(3, 10), torch.randn(3, 1))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metric_fn_ = getattr(models_module, experiment_config[\"metric\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metric_fn_" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "2.e-3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "lr_scheduler_ = getattr(\n", - " torch.optim.lr_scheduler, experiment_config[\"lr_scheduler\"]\n", - ")\n", - "lr_scheduler_args = experiment_config.get(\"lr_scheduler_args\", {})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\"OneCycleLR\" in str(lr_scheduler_)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datasets_module = importlib.import_module(\"text_recognizer.datasets\")\n", - "data_loader_ = getattr(datasets_module, experiment_config[\"dataloader\"])\n", - "data_loader_args = experiment_config.get(\"data_loader_args\", {})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_loader_(**data_loader_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "cuda = \"cuda:0\"" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import re\n", - "cleanString = re.sub('[^A-Za-z]+','', cuda )" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "cleanString = re.sub('[^0-9]+','', cuda )" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'0'" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cleanString" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "([28, 28], 1)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "([28, 28], ) + (1,)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list(range(3-1))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(1,)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tuple([1])" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from glob import glob" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt']" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "glob(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/text_recognizer/weights/CharacterModel_*MLP_weights.pt\")" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "def test(a, b, c, d):\n", - " print(a,b,c,d)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "f = {\"a\": 2, \"b\": 3, \"c\": 4}" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('a', 2), ('b', 3), ('c', 4)])\n" - ] - } - ], - "source": [ - "print(f.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2 3 4 1\n" - ] - } - ], - "source": [ - "test(**f, d=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "path = \"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/*\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "l = glob(path)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "l.sort()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_124928' in l" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_124928',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_141139',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_141213',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_141433',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_141702',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_145028',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_150212',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_150301',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_150317',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_151135',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_151408',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_153144',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_153207',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_153310',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_175150',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_180741',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_181933',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_183347',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_190044',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_190633',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_190738',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_191111',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_191310',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_191412',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_191504',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0721_191826',\n", - " '/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/CharacterModel_Emnist_MLP/0722_191559']" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "l" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "class ModeKeys:\n", - " \"\"\"Mode keys for CallbackList.\"\"\"\n", - "\n", - " TRAIN = \"train\"\n", - " VALIDATION = \"validation\"" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "m = ModeKeys()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'train'" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "m.TRAIN" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([1.00000000e-05, 1.26485522e-05, 1.59985872e-05, 2.02358965e-05,\n", - " 2.55954792e-05, 3.23745754e-05, 4.09491506e-05, 5.17947468e-05,\n", - " 6.55128557e-05, 8.28642773e-05, 1.04811313e-04, 1.32571137e-04,\n", - " 1.67683294e-04, 2.12095089e-04, 2.68269580e-04, 3.39322177e-04,\n", - " 4.29193426e-04, 5.42867544e-04, 6.86648845e-04, 8.68511374e-04,\n", - " 1.09854114e-03, 1.38949549e-03, 1.75751062e-03, 2.22299648e-03,\n", - " 2.81176870e-03, 3.55648031e-03, 4.49843267e-03, 5.68986603e-03,\n", - " 7.19685673e-03, 9.10298178e-03, 1.15139540e-02, 1.45634848e-02,\n", - " 1.84206997e-02, 2.32995181e-02, 2.94705170e-02, 3.72759372e-02,\n", - " 4.71486636e-02, 5.96362332e-02, 7.54312006e-02, 9.54095476e-02,\n", - " 1.20679264e-01, 1.52641797e-01, 1.93069773e-01, 2.44205309e-01,\n", - " 3.08884360e-01, 3.90693994e-01, 4.94171336e-01, 6.25055193e-01,\n", - " 7.90604321e-01, 1.00000000e+00])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.logspace(-5, 0, base=10)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.018420699693267165" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.random.choice(np.logspace(-5, 0, base=10))" - ] - }, - { - "cell_type": "code", - "execution_count": 51, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ - "import tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tqdm.notebook.tqdm_notebook" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tqdm.auto.tqdm" + "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork" ] }, { @@ -1127,274 +34,24 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "tqdm.auto.tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [], - "source": [ - "def test():\n", - " for i in tqdm.auto.tqdm(range(9)):\n", - " pass\n", - " print(i)\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e1d3b25d4ee141e882e316ec54e79d60", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "8\n" - ] - } - ], - "source": [ - "test()" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [], - "source": [ - "from time import sleep" - ] - }, - { - "cell_type": "code", - "execution_count": 71, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "41b743273ce14236bcb65782dbcd2e75", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "pbar = tqdm.auto.tqdm([\"a\", \"b\", \"c\", \"d\"], leave=True)\n", - "for char in pbar:\n", - " pbar.set_description(\"Processing %s\" % char)\n", - "# pbar.set_prefix()\n", - " sleep(0.25)\n", - "pbar.set_postfix({\"hej\": 0.32})" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "metadata": {}, - "outputs": [], - "source": [ - "pbar.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cb5ad8d6109f4b1495b8fc7422bafd01", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "with tqdm.auto.tqdm(total=10, bar_format=\"{postfix[0]} {postfix[1][value]:>8.2g}\",\n", - " postfix=[\"Batch\", dict(value=0)]) as t:\n", - " for i in range(10):\n", - " sleep(0.1)\n", - "# t.postfix[2][\"value\"] = 3 \n", - " t.postfix[1][\"value\"] = i / 2\n", - " t.update()" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0b341d49ad074823881e84a538bcad0c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "with tqdm.auto.tqdm(total=100, leave=True) as pbar:\n", - " for i in range(2):\n", - " for i in range(10):\n", - " sleep(0.1)\n", - " pbar.update(10)\n", - " pbar.set_postfix({\"adaf\": 23})\n", - " pbar.set_postfix({\"hej\": 0.32})\n", - " pbar.reset()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "IdentityBlock(\n", - " (blocks): Identity()\n", - " (activation_fn): ReLU(inplace=True)\n", - " (shortcut): Identity()\n", - ")" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "IdentityBlock(32, 64)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ResidualBlock(\n", - " (blocks): Identity()\n", - " (activation_fn): ReLU(inplace=True)\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ResidualBlock(32, 64)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "BasicBlock(\n", - " (blocks): Sequential(\n", - " (0): Sequential(\n", - " (0): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (1): ReLU(inplace=True)\n", - " (2): Sequential(\n", - " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (activation_fn): ReLU(inplace=True)\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - ")\n" - ] - } - ], + "outputs": [], "source": [ "dummy = torch.ones((1, 32, 224, 224))\n", "\n", @@ -1405,39 +62,9 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "BottleNeckBlock(\n", - " (blocks): Sequential(\n", - " (0): Sequential(\n", - " (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (1): ReLU(inplace=True)\n", - " (2): Sequential(\n", - " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (3): ReLU(inplace=True)\n", - " (4): Sequential(\n", - " (0): Conv2dAuto(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (activation_fn): ReLU(inplace=True)\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - ")\n" - ] - } - ], + "outputs": [], "source": [ "dummy = torch.ones((1, 32, 10, 10))\n", "\n", @@ -1448,20 +75,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 128, 24, 24])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "dummy = torch.ones((1, 64, 48, 48))\n", "\n", @@ -1471,20 +87,9 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[(64, 128), (128, 256), (256, 512)]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "blocks_sizes=[64, 128, 256, 512]\n", "list(zip(blocks_sizes, blocks_sizes[1:]))" @@ -1492,89 +97,110 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "e = Encoder(depths=[1, 1])" - ] - }, - { - "cell_type": "code", - "execution_count": 19, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ - "from torchsummary import summary" + "e = Encoder(depths=[2, 1], block_sizes= [96, 128])" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 75, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "----------------------------------------------------------------\n", - " Layer (type) Output Shape Param #\n", - "================================================================\n", - " Conv2d-1 [-1, 32, 15, 15] 800\n", - " BatchNorm2d-2 [-1, 32, 15, 15] 64\n", - " ReLU-3 [-1, 32, 15, 15] 0\n", - " MaxPool2d-4 [-1, 32, 8, 8] 0\n", - " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n", - " BatchNorm2d-6 [-1, 32, 8, 8] 64\n", - " ReLU-7 [-1, 32, 8, 8] 0\n", - " ReLU-8 [-1, 32, 8, 8] 0\n", - " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n", - " BatchNorm2d-10 [-1, 32, 8, 8] 64\n", - " ReLU-11 [-1, 32, 8, 8] 0\n", - " ReLU-12 [-1, 32, 8, 8] 0\n", - " BasicBlock-13 [-1, 32, 8, 8] 0\n", - " ResidualLayer-14 [-1, 32, 8, 8] 0\n", - " Conv2d-15 [-1, 64, 4, 4] 2,048\n", - " BatchNorm2d-16 [-1, 64, 4, 4] 128\n", - " Conv2dAuto-17 [-1, 64, 4, 4] 18,432\n", - " BatchNorm2d-18 [-1, 64, 4, 4] 128\n", - " ReLU-19 [-1, 64, 4, 4] 0\n", - " ReLU-20 [-1, 64, 4, 4] 0\n", - " Conv2dAuto-21 [-1, 64, 4, 4] 36,864\n", - " BatchNorm2d-22 [-1, 64, 4, 4] 128\n", - " ReLU-23 [-1, 64, 4, 4] 0\n", - " ReLU-24 [-1, 64, 4, 4] 0\n", - " BasicBlock-25 [-1, 64, 4, 4] 0\n", - " ResidualLayer-26 [-1, 64, 4, 4] 0\n", - "================================================================\n", - "Total params: 77,152\n", - "Trainable params: 77,152\n", - "Non-trainable params: 0\n", - "----------------------------------------------------------------\n", - "Input size (MB): 0.00\n", - "Forward/backward pass size (MB): 0.43\n", - "Params size (MB): 0.29\n", - "Estimated Total Size (MB): 0.73\n", - "----------------------------------------------------------------\n" - ] + "data": { + "text/plain": [ + "Encoder(\n", + " (gate): Sequential(\n", + " (0): Conv2d(1, 96, kernel_size=(3, 3), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " )\n", + " (blocks): Sequential(\n", + " (0): ResidualLayer(\n", + " (blocks): Sequential(\n", + " (0): BasicBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): None\n", + " )\n", + " (1): BasicBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): None\n", + " )\n", + " )\n", + " )\n", + " (1): ResidualLayer(\n", + " (blocks): Sequential(\n", + " (0): BasicBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(96, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(96, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "summary(e, (1, 28, 28), device=\"cpu\")" + "Encoder(**{\"depths\": [2, 1], \"block_sizes\": [96, 128]})" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ - "resnet = ResidualNetwork(1, 80, activation=\"selu\")" + "from torchsummary import summary" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 70, "metadata": {}, "outputs": [ { @@ -1584,77 +210,57 @@ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", - " Conv2d-1 [-1, 32, 15, 15] 800\n", - " BatchNorm2d-2 [-1, 32, 15, 15] 64\n", - " SELU-3 [-1, 32, 15, 15] 0\n", - " MaxPool2d-4 [-1, 32, 8, 8] 0\n", - " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n", - " BatchNorm2d-6 [-1, 32, 8, 8] 64\n", - " SELU-7 [-1, 32, 8, 8] 0\n", - " SELU-8 [-1, 32, 8, 8] 0\n", - " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n", - " BatchNorm2d-10 [-1, 32, 8, 8] 64\n", - " SELU-11 [-1, 32, 8, 8] 0\n", - " SELU-12 [-1, 32, 8, 8] 0\n", - " BasicBlock-13 [-1, 32, 8, 8] 0\n", - " Conv2dAuto-14 [-1, 32, 8, 8] 9,216\n", - " BatchNorm2d-15 [-1, 32, 8, 8] 64\n", - " SELU-16 [-1, 32, 8, 8] 0\n", - " SELU-17 [-1, 32, 8, 8] 0\n", - " Conv2dAuto-18 [-1, 32, 8, 8] 9,216\n", - " BatchNorm2d-19 [-1, 32, 8, 8] 64\n", - " SELU-20 [-1, 32, 8, 8] 0\n", - " SELU-21 [-1, 32, 8, 8] 0\n", - " BasicBlock-22 [-1, 32, 8, 8] 0\n", - " ResidualLayer-23 [-1, 32, 8, 8] 0\n", - " Conv2d-24 [-1, 64, 4, 4] 2,048\n", - " BatchNorm2d-25 [-1, 64, 4, 4] 128\n", - " Conv2dAuto-26 [-1, 64, 4, 4] 18,432\n", - " BatchNorm2d-27 [-1, 64, 4, 4] 128\n", - " SELU-28 [-1, 64, 4, 4] 0\n", - " SELU-29 [-1, 64, 4, 4] 0\n", - " Conv2dAuto-30 [-1, 64, 4, 4] 36,864\n", - " BatchNorm2d-31 [-1, 64, 4, 4] 128\n", - " SELU-32 [-1, 64, 4, 4] 0\n", - " SELU-33 [-1, 64, 4, 4] 0\n", - " BasicBlock-34 [-1, 64, 4, 4] 0\n", - " Conv2dAuto-35 [-1, 64, 4, 4] 36,864\n", - " BatchNorm2d-36 [-1, 64, 4, 4] 128\n", - " SELU-37 [-1, 64, 4, 4] 0\n", - " SELU-38 [-1, 64, 4, 4] 0\n", - " Conv2dAuto-39 [-1, 64, 4, 4] 36,864\n", - " BatchNorm2d-40 [-1, 64, 4, 4] 128\n", - " SELU-41 [-1, 64, 4, 4] 0\n", - " SELU-42 [-1, 64, 4, 4] 0\n", - " BasicBlock-43 [-1, 64, 4, 4] 0\n", - " ResidualLayer-44 [-1, 64, 4, 4] 0\n", - " Encoder-45 [-1, 64, 4, 4] 0\n", - " Reduce-46 [-1, 64] 0\n", - " Linear-47 [-1, 80] 5,200\n", - " Decoder-48 [-1, 80] 0\n", + " Conv2d-1 [-1, 96, 16, 16] 864\n", + " BatchNorm2d-2 [-1, 96, 16, 16] 192\n", + " ReLU-3 [-1, 96, 16, 16] 0\n", + " MaxPool2d-4 [-1, 96, 8, 8] 0\n", + " Conv2dAuto-5 [-1, 96, 8, 8] 82,944\n", + " BatchNorm2d-6 [-1, 96, 8, 8] 192\n", + " ReLU-7 [-1, 96, 8, 8] 0\n", + " ReLU-8 [-1, 96, 8, 8] 0\n", + " Conv2dAuto-9 [-1, 96, 8, 8] 82,944\n", + " BatchNorm2d-10 [-1, 96, 8, 8] 192\n", + " ReLU-11 [-1, 96, 8, 8] 0\n", + " ReLU-12 [-1, 96, 8, 8] 0\n", + " BasicBlock-13 [-1, 96, 8, 8] 0\n", + " Conv2dAuto-14 [-1, 96, 8, 8] 82,944\n", + " BatchNorm2d-15 [-1, 96, 8, 8] 192\n", + " ReLU-16 [-1, 96, 8, 8] 0\n", + " ReLU-17 [-1, 96, 8, 8] 0\n", + " Conv2dAuto-18 [-1, 96, 8, 8] 82,944\n", + " BatchNorm2d-19 [-1, 96, 8, 8] 192\n", + " ReLU-20 [-1, 96, 8, 8] 0\n", + " ReLU-21 [-1, 96, 8, 8] 0\n", + " BasicBlock-22 [-1, 96, 8, 8] 0\n", + " ResidualLayer-23 [-1, 96, 8, 8] 0\n", + " Conv2d-24 [-1, 128, 4, 4] 12,288\n", + " BatchNorm2d-25 [-1, 128, 4, 4] 256\n", + " Conv2dAuto-26 [-1, 128, 4, 4] 110,592\n", + " BatchNorm2d-27 [-1, 128, 4, 4] 256\n", + " ReLU-28 [-1, 128, 4, 4] 0\n", + " ReLU-29 [-1, 128, 4, 4] 0\n", + " Conv2dAuto-30 [-1, 128, 4, 4] 147,456\n", + " BatchNorm2d-31 [-1, 128, 4, 4] 256\n", + " ReLU-32 [-1, 128, 4, 4] 0\n", + " ReLU-33 [-1, 128, 4, 4] 0\n", + " BasicBlock-34 [-1, 128, 4, 4] 0\n", + " ResidualLayer-35 [-1, 128, 4, 4] 0\n", "================================================================\n", - "Total params: 174,896\n", - "Trainable params: 174,896\n", + "Total params: 604,704\n", + "Trainable params: 604,704\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.00\n", - "Forward/backward pass size (MB): 0.65\n", - "Params size (MB): 0.67\n", - "Estimated Total Size (MB): 1.32\n", + "Forward/backward pass size (MB): 1.69\n", + "Params size (MB): 2.31\n", + "Estimated Total Size (MB): 4.00\n", "----------------------------------------------------------------\n" ] } ], "source": [ - "summary(resnet, (1, 28, 28), device=\"cpu\")" + "summary(e, (1, 28, 28), device=\"cpu\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { -- cgit v1.2.3-70-g09d2