summaryrefslogtreecommitdiff
path: root/notebooks/05c-test-model-end-to-end.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/05c-test-model-end-to-end.ipynb')
-rw-r--r--notebooks/05c-test-model-end-to-end.ipynb898
1 files changed, 0 insertions, 898 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
deleted file mode 100644
index 23361b6..0000000
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ /dev/null
@@ -1,898 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 32,
- "id": "1e40a88b",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "\n",
- "%matplotlib inline\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "import torch\n",
- "from torch import nn\n",
- "from importlib.util import find_spec\n",
- "if find_spec(\"text_recognizer\") is None:\n",
- " import sys\n",
- " sys.path.append('..')\n",
- " "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "id": "38fb3d9d-a163-4b72-981f-f31b51be39f2",
- "metadata": {},
- "outputs": [],
- "source": [
- "from hydra import compose, initialize\n",
- "from omegaconf import OmegaConf\n",
- "from hydra.utils import instantiate"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 47,
- "id": "74780b21-3313-452b-b580-703cac878416",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "callbacks:\n",
- " model_checkpoint:\n",
- " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n",
- " monitor: val/loss\n",
- " save_top_k: 1\n",
- " save_last: true\n",
- " mode: min\n",
- " verbose: false\n",
- " dirpath: checkpoints/\n",
- " filename: '{epoch:02d}'\n",
- " learning_rate_monitor:\n",
- " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n",
- " logging_interval: step\n",
- " log_momentum: false\n",
- " watch_model:\n",
- " _target_: callbacks.wandb_callbacks.WatchModel\n",
- " log: all\n",
- " log_freq: 100\n",
- " upload_ckpts_as_artifact:\n",
- " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
- " ckpt_dir: checkpoints/\n",
- " upload_best_only: true\n",
- " log_image_reconstruction:\n",
- " _target_: callbacks.wandb_callbacks.LogReconstuctedImages\n",
- " num_samples: 8\n",
- "criterion:\n",
- " _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss\n",
- " reconstruction_loss:\n",
- " _target_: torch.nn.L1Loss\n",
- " reduction: mean\n",
- " discriminator:\n",
- " _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator\n",
- " in_channels: 1\n",
- " num_channels: 32\n",
- " num_layers: 3\n",
- " vq_loss_weight: 1.0\n",
- " discriminator_weight: 1.0\n",
- "datamodule:\n",
- " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n",
- " batch_size: 8\n",
- " num_workers: 12\n",
- " train_fraction: 0.8\n",
- " augment: true\n",
- " pin_memory: false\n",
- " word_pieces: true\n",
- "logger:\n",
- " wandb:\n",
- " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n",
- " project: text-recognizer\n",
- " name: null\n",
- " save_dir: .\n",
- " offline: false\n",
- " id: null\n",
- " log_model: false\n",
- " prefix: ''\n",
- " job_type: train\n",
- " group: ''\n",
- " tags: []\n",
- "mapping:\n",
- " _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping\n",
- " num_features: 1000\n",
- " tokens: iamdb_1kwp_tokens_1000.txt\n",
- " lexicon: iamdb_1kwp_lex_1000.txt\n",
- " data_dir: null\n",
- " use_words: false\n",
- " prepend_wordsep: false\n",
- " special_tokens:\n",
- " - <s>\n",
- " - <e>\n",
- " - <p>\n",
- " extra_symbols:\n",
- " - '\n",
- "\n",
- " '\n",
- "model:\n",
- " _target_: text_recognizer.models.vqgan.VQGANLitModel\n",
- "network:\n",
- " encoder:\n",
- " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n",
- " in_channels: 1\n",
- " hidden_dim: 32\n",
- " channels_multipliers:\n",
- " - 1\n",
- " - 2\n",
- " - 4\n",
- " - 8\n",
- " - 8\n",
- " dropout_rate: 0.25\n",
- " decoder:\n",
- " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n",
- " out_channels: 1\n",
- " hidden_dim: 32\n",
- " channels_multipliers:\n",
- " - 8\n",
- " - 8\n",
- " - 4\n",
- " - 1\n",
- " dropout_rate: 0.25\n",
- " _target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n",
- " hidden_dim: 256\n",
- " embedding_dim: 32\n",
- " num_embeddings: 1024\n",
- " decay: 0.99\n",
- "trainer:\n",
- " _target_: pytorch_lightning.Trainer\n",
- " stochastic_weight_avg: false\n",
- " auto_scale_batch_size: binsearch\n",
- " auto_lr_find: false\n",
- " gradient_clip_val: 0\n",
- " fast_dev_run: false\n",
- " gpus: 1\n",
- " precision: 16\n",
- " max_epochs: 256\n",
- " terminate_on_nan: true\n",
- " weights_summary: top\n",
- " limit_train_batches: 1.0\n",
- " limit_val_batches: 1.0\n",
- " limit_test_batches: 1.0\n",
- " resume_from_checkpoint: null\n",
- "seed: 4711\n",
- "tune: false\n",
- "train: true\n",
- "test: true\n",
- "logging: INFO\n",
- "work_dir: ${hydra:runtime.cwd}\n",
- "debug: false\n",
- "print_config: false\n",
- "ignore_warnings: true\n",
- "summary: null\n",
- "lr_schedulers:\n",
- " generator:\n",
- " _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n",
- " T_max: 256\n",
- " eta_min: 0.0\n",
- " last_epoch: -1\n",
- " interval: epoch\n",
- " monitor: val/loss\n",
- " discriminator:\n",
- " _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n",
- " T_max: 256\n",
- " eta_min: 0.0\n",
- " last_epoch: -1\n",
- " interval: epoch\n",
- " monitor: val/loss\n",
- "optimizers:\n",
- " generator:\n",
- " _target_: madgrad.MADGRAD\n",
- " lr: 0.001\n",
- " momentum: 0.5\n",
- " weight_decay: 0\n",
- " eps: 1.0e-06\n",
- " parameters: network\n",
- " discriminator:\n",
- " _target_: madgrad.MADGRAD\n",
- " lr: 0.001\n",
- " momentum: 0.5\n",
- " weight_decay: 0\n",
- " eps: 1.0e-06\n",
- " parameters: loss_fn.discriminator\n",
- "\n",
- "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'text_recognizer.criterions.vqgan_loss.VQGANLoss', 'reconstruction_loss': {'_target_': 'torch.nn.L1Loss', 'reduction': 'mean'}, 'discriminator': {'_target_': 'text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator', 'in_channels': 1, 'num_channels': 32, 'num_layers': 3}, 'vq_loss_weight': 1.0, 'discriminator_weight': 1.0}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 8, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqgan.VQGANLitModel'}, 'network': {'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 8, 8], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [8, 8, 4, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 256, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 256, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': False, 'ignore_warnings': True, 'summary': None, 'lr_schedulers': {'generator': {'_target_': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'T_max': 256, 'eta_min': 0.0, 'last_epoch': -1, 'interval': 'epoch', 'monitor': 'val/loss'}, 'discriminator': {'_target_': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'T_max': 256, 'eta_min': 0.0, 'last_epoch': -1, 'interval': 'epoch', 'monitor': 'val/loss'}}, 'optimizers': {'generator': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.5, 'weight_decay': 0, 'eps': 1e-06, 'parameters': 'network'}, 'discriminator': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.5, 'weight_decay': 0, 'eps': 1e-06, 'parameters': 'loss_fn.discriminator'}}}\n"
- ]
- }
- ],
- "source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqgan\"])\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "205a03e8-7aa1-407f-afa5-92693715b677",
- "metadata": {},
- "outputs": [],
- "source": [
- "net = instantiate(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c74384f0-754e-4c29-8f06-339372d6e4c1",
- "metadata": {},
- "outputs": [],
- "source": [
- "from torchsummary import summary"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5ebab599-2497-42f8-b54b-1663ee66fde9",
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(net, (1, 576, 640), device=\"cpu\");"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6ba3f405-5948-465d-a7b8-459c84345034",
- "metadata": {},
- "outputs": [],
- "source": [
- "net = net.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5c998137-0967-488f-a572-a5f5a6b86353",
- "metadata": {},
- "outputs": [],
- "source": [
- "x = torch.randn(16, 1, 576, 640)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "920aeeb2-088c-4ea0-84a2-a2532d4f697a",
- "metadata": {},
- "outputs": [],
- "source": [
- "x = x.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "119ab631-fb3a-47a3-afc2-0e66260ebe7f",
- "metadata": {},
- "outputs": [],
- "source": [
- "xx, l = net(x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7ccdec29-3952-460d-95b4-820b03aa4997",
- "metadata": {},
- "outputs": [],
- "source": [
- "xx.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a847084a-a65d-4072-ae1e-ae5d85a1664a",
- "metadata": {},
- "outputs": [],
- "source": [
- "l"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9b21480a-707b-41de-b75d-30fb467973a4",
- "metadata": {},
- "outputs": [],
- "source": [
- "vq(x)[0].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cba1096d-8832-4955-88c9-a8650cf968cf",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "443a52d9-09f3-4e24-8a23-e0397a65f747",
- "metadata": {},
- "outputs": [],
- "source": [
- "import glob"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "78541477-6f02-42da-ad75-4a47bb043e79",
- "metadata": {},
- "outputs": [],
- "source": [
- "from pathlib import Path"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bdedced3-e08b-4bec-822c-e5dcd521c6b8",
- "metadata": {},
- "outputs": [],
- "source": [
- "list(Path(code_dir).glob(\"**/*.py\"))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "79771541-c474-46a9-afdf-f74e736d6c16",
- "metadata": {},
- "outputs": [],
- "source": [
- "for path in glob.glob(os.path.join(code_dir, \"**/*.py\"), recursive=True):\n",
- " print(path)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a79a2a20-56df-48b3-b964-22a0def52117",
- "metadata": {},
- "outputs": [],
- "source": [
- "e = Encoder(1, 64, 32, 0.2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5a6fd004-6d7c-4a20-9ed4-508a73b329b2",
- "metadata": {},
- "outputs": [],
- "source": [
- "d = Decoder(64, 1, 32, 0.2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "82c18401-ea33-4ab6-ace4-03cb6e2e4435",
- "metadata": {},
- "outputs": [],
- "source": [
- "z = e(x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "64f99b20-fa37-4614-b258-5870b7668959",
- "metadata": {},
- "outputs": [],
- "source": [
- "xh = d(z)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4a81e7de-1203-4ab6-9562-37341e135daf",
- "metadata": {},
- "outputs": [],
- "source": [
- "xh.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "204d167b-dce0-4dd7-b0e1-88a53859fd28",
- "metadata": {},
- "outputs": [],
- "source": [
- "a = [2, 2]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b77a6e8a-070d-46d3-9470-a5729eace57f",
- "metadata": {},
- "outputs": [],
- "source": [
- "a += [1, 1]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "741adac8-acc4-4715-afe9-07d3522cab62",
- "metadata": {},
- "outputs": [],
- "source": [
- "a"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "49b894be-5947-4e06-b698-bb990bf2c64c",
- "metadata": {},
- "outputs": [],
- "source": [
- "x"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4371af97-1f3b-4c5e-9812-3fb97d07c1cb",
- "metadata": {},
- "outputs": [],
- "source": [
- "576 // (2 * 4)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "28224cc8-79e0-481f-b24c-85bd0ef69f0a",
- "metadata": {},
- "outputs": [],
- "source": [
- "16 // 2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0",
- "metadata": {},
- "outputs": [],
- "source": [
- "from hydra import compose, initialize\n",
- "from omegaconf import OmegaConf\n",
- "from hydra.utils import instantiate"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "764c8736-7d68-4261-a57d-face10ebbf42",
- "metadata": {},
- "outputs": [],
- "source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqvae\"])\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c1a9aa6b-6405-4ffe-b065-02340762476a",
- "metadata": {},
- "outputs": [],
- "source": [
- "mapping = instantiate(cfg.mapping)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86",
- "metadata": {},
- "outputs": [],
- "source": [
- "network = instantiate(cfg.network)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c",
- "metadata": {},
- "outputs": [],
- "source": [
- "x = torch.rand(1, 1, 576, 640)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282",
- "metadata": {},
- "outputs": [],
- "source": [
- "network.encode(x)[0].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0",
- "metadata": {},
- "outputs": [],
- "source": [
- "t, l = network(x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9a9450d2-f45d-4823-adac-68a8ea05ed1d",
- "metadata": {},
- "outputs": [],
- "source": [
- "l"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "93b8c90f-788a-4095-aa7a-55b34f0ddaaf",
- "metadata": {},
- "outputs": [],
- "source": [
- "from torch.nn import functional as F\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c9983788-2dae-4375-a821-a64cd1c68edf",
- "metadata": {},
- "outputs": [],
- "source": [
- "F.mse_loss(x, t) + l"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "29b128ca-80b7-481e-bb3c-44f109c7d292",
- "metadata": {},
- "outputs": [],
- "source": [
- "t.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "23c9d90c-042b-423e-ab85-18449e29ded4",
- "metadata": {},
- "outputs": [],
- "source": [
- "576 / 4"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "047ebc09-1c74-44a7-a314-1099f09722fe",
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randint(0, 1006, (1, 451)).cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "87372dde-2b1a-432b-ab79-0b116124c724",
- "metadata": {},
- "outputs": [],
- "source": [
- "z = torch.rand((1, 36 * 40, 128)).cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cf7ca9bf-cafa-4128-9db7-046c16933a52",
- "metadata": {},
- "outputs": [],
- "source": [
- "network = network.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "dfceaa5f-9ad8-4d33-addb-c56e8da48356",
- "metadata": {},
- "outputs": [],
- "source": [
- "network.decode(z, t).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9105fbbb-4363-4d3e-a01e-bc519c3b9c3a",
- "metadata": {},
- "outputs": [],
- "source": [
- "decoder = decoder.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c5797ec4-7a6a-46fd-8adc-265df44d0341",
- "metadata": {},
- "outputs": [],
- "source": [
- "decoder(z, t).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a23893a9-a0da-4327-a617-dc0c2011e5e8",
- "metadata": {},
- "outputs": [],
- "source": [
- "OmegaConf.set_struct(cfg, False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a6fae1fa-492d-4648-80fd-1c0dac659b02",
- "metadata": {},
- "outputs": [],
- "source": [
- "datamodule = instantiate(cfg.datamodule, mapping=mapping)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "514053ef-fcac-4f3c-a7c8-72c6927d6798",
- "metadata": {},
- "outputs": [],
- "source": [
- "datamodule.prepare_data()\n",
- "datamodule.setup()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4bad950b-a197-4c60-ad89-903124659a98",
- "metadata": {},
- "outputs": [],
- "source": [
- "len(datamodule.train_dataloader())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7db05cbd-48b3-43fa-a99a-353126311879",
- "metadata": {},
- "outputs": [],
- "source": [
- "mapping"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f6e01c15-9a1b-4036-87ae-78716c592264",
- "metadata": {},
- "outputs": [],
- "source": [
- "config = cfg"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4dc475fc-31f4-487e-88c8-b0f445131f5b",
- "metadata": {},
- "outputs": [],
- "source": [
- "loss_fn = instantiate(cfg.criterion)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c5c8ed64-d98c-47b5-baf2-1ba57a6c882f",
- "metadata": {},
- "outputs": [],
- "source": [
- "import hydra"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b5ff5b24-f804-402b-a8ab-f366443025ca",
- "metadata": {},
- "outputs": [],
- "source": [
- " model = hydra.utils.instantiate(\n",
- " config.model,\n",
- " mapping=mapping,\n",
- " network=network,\n",
- " loss_fn=loss_fn,\n",
- " optimizer_config=config.optimizer,\n",
- " lr_scheduler_config=config.lr_scheduler,\n",
- " _recursive_=False,\n",
- " )\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "99f8a39f-8b10-4f7d-8bff-52794fd48717",
- "metadata": {},
- "outputs": [],
- "source": [
- "mapping.get_index"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "af2c8cfa-0b45-4681-b671-0f97ace62516",
- "metadata": {},
- "outputs": [],
- "source": [
- "net = instantiate(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f",
- "metadata": {},
- "outputs": [],
- "source": [
- "net"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "40be59bc-db79-4af1-9df4-e280f7a56481",
- "metadata": {},
- "outputs": [],
- "source": [
- "img = torch.rand(4, 1, 576, 640)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d5a8f10b-edf5-4a18-9747-f016db72c384",
- "metadata": {},
- "outputs": [],
- "source": [
- "y = torch.randint(0, 1006, (4, 451))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "19423ef1-3d98-4af3-8748-fdd3bb817300",
- "metadata": {},
- "outputs": [],
- "source": [
- "y.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0712ee7e-4f66-4fb1-bc91-d8a127eb7ac7",
- "metadata": {},
- "outputs": [],
- "source": [
- "net = net.cuda()\n",
- "img = img.cuda()\n",
- "y = y.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "719154b4-47db-4c91-bae4-8c572c4a4536",
- "metadata": {},
- "outputs": [],
- "source": [
- "net(img, y).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bcb7db0f-0afe-44eb-9bb7-b988fbead95a",
- "metadata": {},
- "outputs": [],
- "source": [
- "from torchsummary import summary"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "31af8ee1-28d3-46b8-a847-6506d29bc45c",
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(net, [(1, 576, 640), (451,)], device=\"cpu\", depth=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4d6d836f-d169-48b4-92e6-ca17179e6f85",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}