{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "1e40a88b", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import torch\n", "from torch import nn\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')\n", " " ] }, { "cell_type": "code", "execution_count": 2, "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0", "metadata": {}, "outputs": [], "source": [ "from hydra import compose, initialize\n", "from omegaconf import OmegaConf\n", "from hydra.utils import instantiate" ] }, { "cell_type": "code", "execution_count": 3, "id": "764c8736-7d68-4261-a57d-face10ebbf42", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "callbacks:\n", " model_checkpoint:\n", " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n", " monitor: val/loss\n", " save_top_k: 1\n", " save_last: true\n", " mode: min\n", " verbose: false\n", " dirpath: checkpoints/\n", " filename: '{epoch:02d}'\n", " learning_rate_monitor:\n", " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n", " logging_interval: step\n", " log_momentum: false\n", " watch_model:\n", " _target_: callbacks.wandb_callbacks.WatchModel\n", " log: all\n", " log_freq: 100\n", " upload_code_as_artifact:\n", " _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact\n", " project_dir: ${work_dir}/text_recognizer\n", " upload_ckpts_as_artifact:\n", " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", " ckpt_dir: checkpoints/\n", " upload_best_only: true\n", " log_image_reconstruction:\n", " _target_: callbacks.wandb_callbacks.LogReconstuctedImages\n", " num_samples: 8\n", "criterion:\n", " _target_: torch.nn.MSELoss\n", " reduction: mean\n", "datamodule:\n", " _target_: 
text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n", " batch_size: 32\n", " num_workers: 12\n", " train_fraction: 0.8\n", " augment: true\n", " pin_memory: false\n", " word_pieces: true\n", "logger:\n", " wandb:\n", " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n", " project: text-recognizer\n", " name: null\n", " save_dir: .\n", " offline: false\n", " id: null\n", " log_model: false\n", " prefix: ''\n", " job_type: train\n", " group: ''\n", " tags: []\n", "lr_scheduler:\n", " _target_: torch.optim.lr_scheduler.OneCycleLR\n", " max_lr: 0.001\n", " total_steps: null\n", " epochs: 64\n", " steps_per_epoch: 624\n", " pct_start: 0.3\n", " anneal_strategy: cos\n", " cycle_momentum: true\n", " base_momentum: 0.85\n", " max_momentum: 0.95\n", " div_factor: 25.0\n", " final_div_factor: 10000.0\n", " three_phase: true\n", " last_epoch: -1\n", " verbose: false\n", "mapping:\n", " _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping\n", " num_features: 1000\n", " tokens: iamdb_1kwp_tokens_1000.txt\n", " lexicon: iamdb_1kwp_lex_1000.txt\n", " data_dir: null\n", " use_words: false\n", " prepend_wordsep: false\n", " special_tokens:\n", " - \n", " - \n", " -

\n", " extra_symbols:\n", " - '\n", "\n", " '\n", "model:\n", " _target_: text_recognizer.models.vqvae.VQVAELitModel\n", " interval: step\n", " monitor: val/loss\n", "network:\n", " _target_: text_recognizer.networks.vqvae.VQVAE\n", " in_channels: 1\n", " res_channels: 32\n", " num_residual_layers: 2\n", " embedding_dim: 64\n", " num_embeddings: 512\n", " decay: 0.99\n", " activation: mish\n", "optimizer:\n", " _target_: madgrad.MADGRAD\n", " lr: 0.01\n", " momentum: 0.9\n", " weight_decay: 0\n", " eps: 1.0e-06\n", "trainer:\n", " _target_: pytorch_lightning.Trainer\n", " stochastic_weight_avg: false\n", " auto_scale_batch_size: binsearch\n", " auto_lr_find: false\n", " gradient_clip_val: 0\n", " fast_dev_run: false\n", " gpus: 1\n", " precision: 16\n", " max_epochs: 64\n", " terminate_on_nan: true\n", " weights_summary: top\n", " limit_train_batches: 1.0\n", " limit_val_batches: 1.0\n", " limit_test_batches: 1.0\n", " resume_from_checkpoint: null\n", "seed: 4711\n", "tune: false\n", "train: true\n", "test: true\n", "logging: INFO\n", "work_dir: ${hydra:runtime.cwd}\n", "debug: false\n", "print_config: true\n", "ignore_warnings: true\n", "\n", "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 
'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['', '', '

'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss'}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n" ] } ], "source": [ "# context initialization\n", "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n", " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqvae\"])\n", " print(OmegaConf.to_yaml(cfg))\n", " print(cfg)" ] }, { "cell_type": "code", "execution_count": 4, "id": "c1a9aa6b-6405-4ffe-b065-02340762476a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2021-08-04 04:49:04.188 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" ] } ], "source": [ "mapping = instantiate(cfg.mapping)" ] }, { "cell_type": "code", "execution_count": 35, "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86", "metadata": {}, "outputs": [], "source": [ "network = instantiate(cfg.network)" ] }, { "cell_type": "code", "execution_count": 36, "id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c", 
"metadata": {}, "outputs": [], "source": [ "x = torch.rand(1, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": 37, "id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 64, 144, 160])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "network.encode(x)[0].shape" ] }, { "cell_type": "code", "execution_count": 38, "id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([512])\n", "torch.Size([512])\n", "torch.Size([512])\n", "torch.Size([512])\n" ] }, { "data": { "text/plain": [ "torch.Size([1, 1, 576, 640])" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "network(x)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "id": "23c9d90c-042b-423e-ab85-18449e29ded4", "metadata": {}, "outputs": [], "source": [ "576 / 4" ] }, { "cell_type": "code", "execution_count": null, "id": "87372dde-2b1a-432b-ab79-0b116124c724", "metadata": {}, "outputs": [], "source": [ "# Latent from this network's own encoder (shape shown by the encode cell above), moved to the GPU with the network.\n", "z = network.encode(x)[0].cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "cf7ca9bf-cafa-4128-9db7-046c16933a52", "metadata": {}, "outputs": [], "source": [ "network = network.cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "dfceaa5f-9ad8-4d33-addb-c56e8da48356", "metadata": {}, "outputs": [], "source": [ "# NOTE(review): assumes VQVAE.decode takes only the latent returned by encode()[0] -- confirm in text_recognizer.networks.vqvae.\n", "network.decode(z).shape" ] }, {
"cell_type": "code", "execution_count": null, "id": "a23893a9-a0da-4327-a617-dc0c2011e5e8", "metadata": {}, "outputs": [], "source": [ "OmegaConf.set_struct(cfg, False)" ] }, { "cell_type": "code", "execution_count": null, "id": "a6fae1fa-492d-4648-80fd-1c0dac659b02", "metadata": {}, "outputs": [], "source": [ "datamodule = instantiate(cfg.datamodule, mapping=mapping)" ] }, { "cell_type": "code", "execution_count": null, "id": "514053ef-fcac-4f3c-a7c8-72c6927d6798", "metadata": {}, "outputs": [], "source": [ "datamodule.prepare_data()\n", "datamodule.setup()" ] }, { "cell_type": "code", "execution_count": null, "id": "4bad950b-a197-4c60-ad89-903124659a98", "metadata": {}, "outputs": [], "source": [ "len(datamodule.train_dataloader())" ] }, { "cell_type": "code", "execution_count": null, "id": "7db05cbd-48b3-43fa-a99a-353126311879", "metadata": {}, "outputs": [], "source": [ "mapping" ] }, { "cell_type": "code", "execution_count": null, "id": "f6e01c15-9a1b-4036-87ae-78716c592264", "metadata": {}, "outputs": [], "source": [ "config = cfg" ] }, { "cell_type": "code", "execution_count": null, "id": "4dc475fc-31f4-487e-88c8-b0f445131f5b", "metadata": {}, "outputs": [], "source": [ "loss_fn = instantiate(cfg.criterion)" ] }, { "cell_type": "code", "execution_count": null, "id": "c5c8ed64-d98c-47b5-baf2-1ba57a6c882f", "metadata": {}, "outputs": [], "source": [ "import hydra" ] }, { "cell_type": "code", "execution_count": null, "id": "b5ff5b24-f804-402b-a8ab-f366443025ca", "metadata": {}, "outputs": [], "source": [ "# _recursive_=False keeps the optimizer / lr_scheduler sub-configs as configs instead of instantiating them here.\n", "model = hydra.utils.instantiate(\n", "    config.model,\n", "    mapping=mapping,\n", "    network=network,\n", "    loss_fn=loss_fn,\n", "    optimizer_config=config.optimizer,\n", "    lr_scheduler_config=config.lr_scheduler,\n", "    _recursive_=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "99f8a39f-8b10-4f7d-8bff-52794fd48717", "metadata": {}, "outputs": [], "source": [ "mapping.get_index" ] }, { 
"cell_type": "code", "execution_count": null, "id": "af2c8cfa-0b45-4681-b671-0f97ace62516", "metadata": {}, "outputs": [], "source": [ "# Instantiate only the network node: the root config has no _target_ of its own and also holds trainer,\n", "# logger, datamodule and lr_scheduler entries that need extra arguments (e.g. mapping, optimizer).\n", "net = instantiate(cfg.network)" ] }, { "cell_type": "code", "execution_count": null, "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f", "metadata": {}, "outputs": [], "source": [ "net" ] }, { "cell_type": "code", "execution_count": null, "id": "40be59bc-db79-4af1-9df4-e280f7a56481", "metadata": {}, "outputs": [], "source": [ "img = torch.rand(4, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": null, "id": "0712ee7e-4f66-4fb1-bc91-d8a127eb7ac7", "metadata": {}, "outputs": [], "source": [ "net = net.cuda()\n", "img = img.cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "719154b4-47db-4c91-bae4-8c572c4a4536", "metadata": {}, "outputs": [], "source": [ "# Called with the image alone and indexed with [0], matching the earlier network(x)[0] cell.\n", "net(img)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "id": "bcb7db0f-0afe-44eb-9bb7-b988fbead95a", "metadata": {}, "outputs": [], "source": [ "from torchsummary import summary" ] }, { "cell_type": "code", "execution_count": null, "id": "31af8ee1-28d3-46b8-a847-6506d29bc45c", "metadata": {}, "outputs": [], "source": [ "# Summarize on the same device the network was moved to above; the network takes a single image input.\n", "summary(net, [(1, 576, 640)], device=\"cuda\", depth=2)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 }