diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
commit | 3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch) | |
tree | 136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /notebooks/05c-test-model-end-to-end.ipynb | |
parent | 1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff) |
Fix VQVAE into en/decoder, bug in wandb artifact code uploading
Diffstat (limited to 'notebooks/05c-test-model-end-to-end.ipynb')
-rw-r--r-- | notebooks/05c-test-model-end-to-end.ipynb | 526 |
1 files changed, 306 insertions, 220 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index 913eafd..7996257 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "1e40a88b", "metadata": {}, "outputs": [], @@ -25,7 +25,294 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "id": "38fb3d9d-a163-4b72-981f-f31b51be39f2", + "metadata": {}, + "outputs": [], + "source": [ + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf\n", + "from hydra.utils import instantiate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74780b21-3313-452b-b580-703cac878416", + "metadata": {}, + "outputs": [], + "source": [ + "# context initialization\n", + "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"vqvae\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "205a03e8-7aa1-407f-afa5-92693715b677", + "metadata": {}, + "outputs": [], + "source": [ + "net = instantiate(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c74384f0-754e-4c29-8f06-339372d6e4c1", + "metadata": {}, + "outputs": [], + "source": [ + "from torchsummary import summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ebab599-2497-42f8-b54b-1663ee66fde9", + "metadata": {}, + "outputs": [], + "source": [ + "summary(net, (1, 576, 640), device=\"cpu\");" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ba3f405-5948-465d-a7b8-459c84345034", + "metadata": {}, + "outputs": [], + "source": [ + "net = net.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c998137-0967-488f-a572-a5f5a6b86353", + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(16, 1, 576, 640)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "920aeeb2-088c-4ea0-84a2-a2532d4f697a", + "metadata": {}, + "outputs": [], + "source": [ + "x = x.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "119ab631-fb3a-47a3-afc2-0e66260ebe7f", + "metadata": {}, + "outputs": [], + "source": [ + "xx, l = net(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ccdec29-3952-460d-95b4-820b03aa4997", + "metadata": {}, + "outputs": [], + "source": [ + "xx.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a847084a-a65d-4072-ae1e-ae5d85a1664a", + "metadata": {}, + "outputs": [], + "source": [ + "l" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b21480a-707b-41de-b75d-30fb467973a4", + "metadata": {}, + "outputs": [], + "source": [ + "vq(x)[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cba1096d-8832-4955-88c9-a8650cf968cf", + "metadata": {}, + "outputs": [], + "source": [ + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "443a52d9-09f3-4e24-8a23-e0397a65f747", + "metadata": {}, + "outputs": [], + "source": [ + "import glob" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78541477-6f02-42da-ad75-4a47bb043e79", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bdedced3-e08b-4bec-822c-e5dcd521c6b8", + "metadata": {}, + "outputs": [], + "source": [ + "list(Path(code_dir).glob(\"**/*.py\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79771541-c474-46a9-afdf-f74e736d6c16", + "metadata": {}, + "outputs": [], + "source": [ + "for path in glob.glob(os.path.join(code_dir, \"**/*.py\"), recursive=True):\n", + " print(path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a79a2a20-56df-48b3-b964-22a0def52117", + "metadata": {}, + "outputs": [], + "source": [ + "e = Encoder(1, 64, 32, 0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a6fd004-6d7c-4a20-9ed4-508a73b329b2", + "metadata": {}, + "outputs": [], + "source": [ + "d = Decoder(64, 1, 32, 0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82c18401-ea33-4ab6-ace4-03cb6e2e4435", + "metadata": {}, + "outputs": [], + "source": [ + "z = e(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64f99b20-fa37-4614-b258-5870b7668959", + "metadata": {}, + "outputs": [], + "source": [ + "xh = d(z)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a81e7de-1203-4ab6-9562-37341e135daf", + "metadata": {}, + "outputs": [], + "source": [ + "xh.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "204d167b-dce0-4dd7-b0e1-88a53859fd28", + "metadata": {}, + "outputs": [], + "source": [ + "a = [2, 2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b77a6e8a-070d-46d3-9470-a5729eace57f", + "metadata": {}, + "outputs": [], + "source": [ + "a += [1, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "741adac8-acc4-4715-afe9-07d3522cab62", + "metadata": {}, + "outputs": [], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49b894be-5947-4e06-b698-bb990bf2c64c", + "metadata": {}, + "outputs": [], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4371af97-1f3b-4c5e-9812-3fb97d07c1cb", + "metadata": {}, + "outputs": [], + "source": [ + "576 // (2 * 4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28224cc8-79e0-481f-b24c-85bd0ef69f0a", + "metadata": {}, + "outputs": [], + "source": [ + "16 // 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0", "metadata": {}, "outputs": [], @@ -37,148 +324,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "764c8736-7d68-4261-a57d-face10ebbf42", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "callbacks:\n", - " model_checkpoint:\n", - " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n", - " monitor: val/loss\n", - " save_top_k: 1\n", - " save_last: true\n", - " mode: min\n", - " verbose: false\n", - " dirpath: checkpoints/\n", - " filename: '{epoch:02d}'\n", - " learning_rate_monitor:\n", - " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n", - " logging_interval: step\n", - " log_momentum: false\n", - " watch_model:\n", - " _target_: callbacks.wandb_callbacks.WatchModel\n", - " log: all\n", - " log_freq: 100\n", - " upload_code_as_artifact:\n", - " _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact\n", - " project_dir: ${work_dir}/text_recognizer\n", - " upload_ckpts_as_artifact:\n", - " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", - " ckpt_dir: checkpoints/\n", - " upload_best_only: true\n", - " log_image_reconstruction:\n", - " _target_: callbacks.wandb_callbacks.LogReconstuctedImages\n", - " num_samples: 8\n", - "criterion:\n", - " _target_: torch.nn.MSELoss\n", - " reduction: mean\n", - "datamodule:\n", - " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n", - " batch_size: 32\n", - " num_workers: 12\n", - " train_fraction: 0.8\n", - " augment: true\n", - " pin_memory: false\n", - " word_pieces: true\n", - "logger:\n", - " wandb:\n", - " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n", - " project: text-recognizer\n", - " name: null\n", - " save_dir: .\n", - " offline: false\n", - " id: null\n", - " log_model: false\n", - " prefix: ''\n", - " job_type: train\n", - " group: ''\n", - " tags: []\n", - "lr_scheduler:\n", - " _target_: torch.optim.lr_scheduler.OneCycleLR\n", - " max_lr: 0.001\n", - " total_steps: null\n", - " epochs: 64\n", - " steps_per_epoch: 624\n", - " pct_start: 0.3\n", - " anneal_strategy: cos\n", - " cycle_momentum: true\n", - " base_momentum: 0.85\n", - " max_momentum: 0.95\n", - " div_factor: 25.0\n", - " final_div_factor: 10000.0\n", - " three_phase: true\n", - " last_epoch: -1\n", - " verbose: false\n", - "mapping:\n", - " _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping\n", - " num_features: 1000\n", - " tokens: iamdb_1kwp_tokens_1000.txt\n", - " lexicon: iamdb_1kwp_lex_1000.txt\n", - " data_dir: null\n", - " use_words: false\n", - " prepend_wordsep: false\n", - " special_tokens:\n", - " - <s>\n", - " - <e>\n", - " - <p>\n", - " extra_symbols:\n", - " - '\n", - "\n", - " '\n", - "model:\n", - " _target_: text_recognizer.models.vqvae.VQVAELitModel\n", - " interval: step\n", - " monitor: val/loss\n", - " latent_loss_weight: 0.25\n", - "network:\n", - " _target_: text_recognizer.networks.vqvae.VQVAE\n", - " in_channels: 1\n", - " res_channels: 32\n", - " num_residual_layers: 2\n", - " embedding_dim: 64\n", - " num_embeddings: 512\n", - " decay: 0.99\n", - " activation: mish\n", - "optimizer:\n", - " _target_: madgrad.MADGRAD\n", - " lr: 0.01\n", - " momentum: 0.9\n", - " weight_decay: 0\n", - " eps: 1.0e-06\n", - "trainer:\n", - " _target_: pytorch_lightning.Trainer\n", - " stochastic_weight_avg: false\n", - " auto_scale_batch_size: binsearch\n", - " auto_lr_find: false\n", - " gradient_clip_val: 0\n", - " fast_dev_run: false\n", - " gpus: 1\n", - " precision: 16\n", - " max_epochs: 64\n", - " terminate_on_nan: true\n", - " weights_summary: top\n", - " limit_train_batches: 1.0\n", - " limit_val_batches: 1.0\n", - " limit_test_batches: 1.0\n", - " resume_from_checkpoint: null\n", - "seed: 4711\n", - "tune: false\n", - "train: true\n", - "test: true\n", - "logging: INFO\n", - "work_dir: ${hydra:runtime.cwd}\n", - "debug: false\n", - "print_config: true\n", - "ignore_warnings: true\n", - "\n", - "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss', 'latent_loss_weight': 0.25}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n" - ] - } - ], + "outputs": [], "source": [ "# context initialization\n", "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n", @@ -189,25 +338,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "c1a9aa6b-6405-4ffe-b065-02340762476a", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-08-04 05:07:26.480 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" - ] - } - ], + "outputs": [], "source": [ "mapping = instantiate(cfg.mapping)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86", "metadata": {}, "outputs": [], @@ -217,7 +358,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c", "metadata": {}, "outputs": [], @@ -227,70 +368,37 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 64, 144, 160])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "network.encode(x)[0].shape" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([512])\n", - "torch.Size([512])\n", - "torch.Size([512])\n", - "torch.Size([512])\n" - ] - } - ], + "outputs": [], "source": [ "t, l = network(x)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "9a9450d2-f45d-4823-adac-68a8ea05ed1d", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.0188, grad_fn=<AddBackward0>)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "l" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "93b8c90f-788a-4095-aa7a-55b34f0ddaaf", "metadata": {}, "outputs": [], @@ -300,42 +408,20 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "c9983788-2dae-4375-a821-a64cd1c68edf", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.5669, grad_fn=<AddBackward0>)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "F.mse_loss(x, t) + l" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "29b128ca-80b7-481e-bb3c-44f109c7d292", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 1, 576, 640])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "t.shape" ] |