diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 21:43:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 21:43:39 +0200 |
commit | 82f4acabe24e5171c40afa2939a4777ba87bcc30 (patch) | |
tree | 4d327fa26e4662a0447a66375442a9adeb13ea3d /notebooks/05c-test-model-end-to-end.ipynb | |
parent | 240f5e9f20032e82515fa66ce784619527d1041e (diff) |
Add training of VQGAN
Diffstat (limited to 'notebooks/05c-test-model-end-to-end.ipynb')
-rw-r--r-- | notebooks/05c-test-model-end-to-end.ipynb | 308 |
1 files changed, 169 insertions, 139 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index 42621da..23361b6 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 32, "id": "1e40a88b", "metadata": {}, "outputs": [], @@ -25,53 +25,7 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "f40fc669-829c-4de8-83ed-475fc6a0b8c1", - "metadata": {}, - "outputs": [], - "source": [ - "class T:\n", - " def __init__(self):\n", - " self.network = nn.Linear(1, 1)\n", - " \n", - " def get(self):\n", - " return getattr(self, \"network\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "d2bedf96-5388-4c7a-a048-1b97041cbedc", - "metadata": {}, - "outputs": [], - "source": [ - "t = T()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "a6fbe3be-2a9f-4050-a397-7ad982d6cd05", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "<generator object Module.parameters at 0x7f29ad6d6120>" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.get().parameters()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 33, "id": "38fb3d9d-a163-4b72-981f-f31b51be39f2", "metadata": {}, "outputs": [], @@ -83,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 47, "id": "74780b21-3313-452b-b580-703cac878416", "metadata": {}, "outputs": [ @@ -91,49 +45,181 @@ "name": "stdout", "output_type": "stream", "text": [ - "encoder:\n", - " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n", - " in_channels: 1\n", - " hidden_dim: 32\n", - " channels_multipliers:\n", - " - 1\n", - " - 2\n", - " - 4\n", - " - 4\n", - " - 4\n", - " dropout_rate: 0.25\n", - "decoder:\n", - " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n", - " out_channels: 1\n", - " hidden_dim: 32\n", - " channels_multipliers:\n", - " - 4\n", - " - 4\n", - " - 4\n", - " - 2\n", - " - 1\n", - " dropout_rate: 0.25\n", - "_target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n", - "hidden_dim: 128\n", - "embedding_dim: 32\n", - "num_embeddings: 1024\n", - "decay: 0.99\n", + "callbacks:\n", + " model_checkpoint:\n", + " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n", + " monitor: val/loss\n", + " save_top_k: 1\n", + " save_last: true\n", + " mode: min\n", + " verbose: false\n", + " dirpath: checkpoints/\n", + " filename: '{epoch:02d}'\n", + " learning_rate_monitor:\n", + " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n", + " logging_interval: step\n", + " log_momentum: false\n", + " watch_model:\n", + " _target_: callbacks.wandb_callbacks.WatchModel\n", + " log: all\n", + " log_freq: 100\n", + " upload_ckpts_as_artifact:\n", + " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", + " ckpt_dir: checkpoints/\n", + " upload_best_only: true\n", + " log_image_reconstruction:\n", + " _target_: callbacks.wandb_callbacks.LogReconstuctedImages\n", + " num_samples: 8\n", + "criterion:\n", + " _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss\n", + " reconstruction_loss:\n", + " _target_: torch.nn.L1Loss\n", + " reduction: mean\n", + " discriminator:\n", + " _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator\n", + " in_channels: 1\n", + " num_channels: 32\n", + " num_layers: 3\n", + " vq_loss_weight: 1.0\n", + " discriminator_weight: 1.0\n", + "datamodule:\n", + " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n", + " batch_size: 8\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " augment: true\n", + " pin_memory: false\n", + " word_pieces: true\n", + "logger:\n", + " wandb:\n", + " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n", + " project: text-recognizer\n", + " name: null\n", + " save_dir: .\n", + " offline: false\n", + " id: null\n", + " log_model: false\n", + " prefix: ''\n", + " job_type: train\n", + " group: ''\n", + " tags: []\n", + "mapping:\n", + " _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping\n", + " num_features: 1000\n", + " tokens: iamdb_1kwp_tokens_1000.txt\n", + " lexicon: iamdb_1kwp_lex_1000.txt\n", + " data_dir: null\n", + " use_words: false\n", + " prepend_wordsep: false\n", + " special_tokens:\n", + " - <s>\n", + " - <e>\n", + " - <p>\n", + " extra_symbols:\n", + " - '\n", "\n", - "{'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 4, 4], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [4, 4, 4, 2, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 128, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}\n" + " '\n", + "model:\n", + " _target_: text_recognizer.models.vqgan.VQGANLitModel\n", + "network:\n", + " encoder:\n", + " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n", + " in_channels: 1\n", + " hidden_dim: 32\n", + " channels_multipliers:\n", + " - 1\n", + " - 2\n", + " - 4\n", + " - 8\n", + " - 8\n", + " dropout_rate: 0.25\n", + " decoder:\n", + " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n", + " out_channels: 1\n", + " hidden_dim: 32\n", + " channels_multipliers:\n", + " - 8\n", + " - 8\n", + " - 4\n", + " - 1\n", + " dropout_rate: 0.25\n", + " _target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n", + " hidden_dim: 256\n", + " embedding_dim: 32\n", + " num_embeddings: 1024\n", + " decay: 0.99\n", + "trainer:\n", + " _target_: pytorch_lightning.Trainer\n", + " stochastic_weight_avg: false\n", + " auto_scale_batch_size: binsearch\n", + " auto_lr_find: false\n", + " gradient_clip_val: 0\n", + " fast_dev_run: false\n", + " gpus: 1\n", + " precision: 16\n", + " max_epochs: 256\n", + " terminate_on_nan: true\n", + " weights_summary: top\n", + " limit_train_batches: 1.0\n", + " limit_val_batches: 1.0\n", + " limit_test_batches: 1.0\n", + " resume_from_checkpoint: null\n", + "seed: 4711\n", + "tune: false\n", + "train: true\n", + "test: true\n", + "logging: INFO\n", + "work_dir: ${hydra:runtime.cwd}\n", + "debug: false\n", + "print_config: false\n", + "ignore_warnings: true\n", + "summary: null\n", + "lr_schedulers:\n", + " generator:\n", + " _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n", + " T_max: 256\n", + " eta_min: 0.0\n", + " last_epoch: -1\n", + " interval: epoch\n", + " monitor: val/loss\n", + " discriminator:\n", + " _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n", + " T_max: 256\n", + " eta_min: 0.0\n", + " last_epoch: -1\n", + " interval: epoch\n", + " monitor: val/loss\n", + "optimizers:\n", + " generator:\n", + " _target_: madgrad.MADGRAD\n", + " lr: 0.001\n", + " momentum: 0.5\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + " parameters: network\n", + " discriminator:\n", + " _target_: madgrad.MADGRAD\n", + " lr: 0.001\n", + " momentum: 0.5\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + " parameters: loss_fn.discriminator\n", + "\n", + "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'text_recognizer.criterions.vqgan_loss.VQGANLoss', 'reconstruction_loss': {'_target_': 'torch.nn.L1Loss', 'reduction': 'mean'}, 'discriminator': {'_target_': 'text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator', 'in_channels': 1, 'num_channels': 32, 'num_layers': 3}, 'vq_loss_weight': 1.0, 'discriminator_weight': 1.0}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 8, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqgan.VQGANLitModel'}, 'network': {'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 8, 8], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [8, 8, 4, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 256, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 256, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': False, 'ignore_warnings': True, 'summary': None, 'lr_schedulers': {'generator': {'_target_': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'T_max': 256, 'eta_min': 0.0, 'last_epoch': -1, 'interval': 'epoch', 'monitor': 'val/loss'}, 'discriminator': {'_target_': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'T_max': 256, 'eta_min': 0.0, 'last_epoch': -1, 'interval': 'epoch', 'monitor': 'val/loss'}}, 'optimizers': {'generator': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.5, 'weight_decay': 0, 'eps': 1e-06, 'parameters': 'network'}, 'discriminator': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.5, 'weight_decay': 0, 'eps': 1e-06, 'parameters': 'loss_fn.discriminator'}}}\n" ] } ], "source": [ "# context initialization\n", - "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"vqvae\")\n", + "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqgan\"])\n", " print(OmegaConf.to_yaml(cfg))\n", " print(cfg)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "205a03e8-7aa1-407f-afa5-92693715b677", "metadata": {}, "outputs": [], @@ -143,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "c74384f0-754e-4c29-8f06-339372d6e4c1", "metadata": {}, "outputs": [], @@ -153,66 +239,10 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "5ebab599-2497-42f8-b54b-1663ee66fde9", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Encoder: 1-1 [-1, 128, 18, 20] --\n", - "| └─Sequential: 2-1 [-1, 128, 18, 20] --\n", - "| | └─Conv2d: 3-1 [-1, 32, 576, 640] 320\n", - "| | └─Conv2d: 3-2 [-1, 32, 288, 320] 16,416\n", - "| | └─Mish: 3-3 [-1, 32, 288, 320] --\n", - "| | └─Conv2d: 3-4 [-1, 64, 144, 160] 32,832\n", - "| | └─Mish: 3-5 [-1, 64, 144, 160] --\n", - "| | └─Conv2d: 3-6 [-1, 128, 72, 80] 131,200\n", - "| | └─Mish: 3-7 [-1, 128, 72, 80] --\n", - "| | └─Conv2d: 3-8 [-1, 128, 36, 40] 262,272\n", - "| | └─Mish: 3-9 [-1, 128, 36, 40] --\n", - "| | └─Conv2d: 3-10 [-1, 128, 18, 20] 262,272\n", - "| | └─Mish: 3-11 [-1, 128, 18, 20] --\n", - "| | └─Residual: 3-12 [-1, 128, 18, 20] 164,352\n", - "| | └─Residual: 3-13 [-1, 128, 18, 20] 164,352\n", - "├─Conv2d: 1-2 [-1, 32, 18, 20] 4,128\n", - "├─VectorQuantizer: 1-3 [-1, 32, 18, 20] --\n", - "├─Conv2d: 1-4 [-1, 128, 18, 20] 4,224\n", - "├─Decoder: 1-5 [-1, 1, 576, 640] --\n", - "| └─Sequential: 2-2 [-1, 1, 576, 640] --\n", - "| | └─Residual: 3-14 [-1, 128, 18, 20] 164,352\n", - "| | └─Residual: 3-15 [-1, 128, 18, 20] 164,352\n", - "| | └─ConvTranspose2d: 3-16 [-1, 128, 36, 40] 262,272\n", - "| | └─Mish: 3-17 [-1, 128, 36, 40] --\n", - "| | └─ConvTranspose2d: 3-18 [-1, 128, 72, 80] 262,272\n", - "| | └─Mish: 3-19 [-1, 128, 72, 80] --\n", - "| | └─ConvTranspose2d: 3-20 [-1, 64, 144, 160] 131,136\n", - "| | └─Mish: 3-21 [-1, 64, 144, 160] --\n", - "| | └─ConvTranspose2d: 3-22 [-1, 32, 288, 320] 32,800\n", - "| | └─Mish: 3-23 [-1, 32, 288, 320] --\n", - "| | └─ConvTranspose2d: 3-24 [-1, 32, 576, 640] 16,416\n", - "| | └─Mish: 3-25 [-1, 32, 576, 640] --\n", - "| | └─Normalize: 3-26 [-1, 32, 576, 640] 64\n", - "| | └─Mish: 3-27 [-1, 32, 576, 640] --\n", - "| | └─Conv2d: 3-28 [-1, 1, 576, 640] 289\n", - "==========================================================================================\n", - "Total params: 2,076,321\n", - "Trainable params: 2,076,321\n", - "Non-trainable params: 0\n", - "Total mult-adds (G): 17.68\n", - "==========================================================================================\n", - "Input size (MB): 1.41\n", - "Forward/backward pass size (MB): 355.17\n", - "Params size (MB): 7.92\n", - "Estimated Total Size (MB): 364.49\n", - "==========================================================================================\n" - ] - } - ], + "outputs": [], "source": [ "summary(net, (1, 576, 640), device=\"cpu\");" ] |