{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a1f3c2e-vqvae-title",
   "metadata": {},
   "source": [
    "# VQ-VAE network smoke tests\n",
    "\n",
    "Instantiates the VQ-VAE network and the `vqvae` experiment from the Hydra configs in `../training/conf/`, pushes random tensors through them, and inspects output shapes and losses. It also builds the datamodule and the training model wrapper from the experiment config.\n",
    "\n",
    "All inputs are random, so the shapes and losses shown are sanity checks only, not results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e40a88b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "from importlib.util import find_spec\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from hydra import compose, initialize\n",
    "from hydra.utils import instantiate\n",
    "from omegaconf import OmegaConf\n",
    "from PIL import Image\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torchsummary import summary\n",
    "\n",
    "# Make the local `text_recognizer` package importable when it is not installed;\n",
    "# assumes the notebook lives one directory below the repository root.\n",
    "if find_spec(\"text_recognizer\") is None:\n",
    "    sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b7d4e91-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes and sizes for the random smoke-test inputs.\n",
    "IMAGE_SHAPE = (1, 576, 640)  # (channels, height, width) of one input image, batch dim excluded\n",
    "BATCH_SIZE = 16\n",
    "NUM_TOKENS = 1006  # exclusive upper bound for random token ids; presumably the mapping's vocabulary size - TODO confirm\n",
    "MAX_OUTPUT_LEN = 451  # length of the random target token sequences\n",
    "LATENT_DIM = 128  # feature size of the random latent passed to `network.decode`\n",
    "# 36 x 40 == (576 / 16) x (640 / 16); presumably the encoder's output grid - TODO confirm.\n",
    "LATENT_GRID = (36, 40)\n",
    "\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Fall back to the CPU so the notebook still runs on machines without a GPU.\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c8e5fa2-network-only",
   "metadata": {},
   "source": [
    "## Network config only (`conf/network/vqvae.yaml`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74780b21-3313-452b-b580-703cac878416",
   "metadata": {},
   "outputs": [],
   "source": [
    "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
    "    network_cfg = compose(config_name=\"vqvae\")\n",
    "\n",
    "print(OmegaConf.to_yaml(network_cfg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "205a03e8-7aa1-407f-afa5-92693715b677",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = instantiate(network_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ebab599-2497-42f8-b54b-1663ee66fde9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise while `net` is still on the CPU, matching device=\"cpu\".\n",
    "summary(net, IMAGE_SHAPE, device=\"cpu\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ba3f405-5948-465d-a7b8-459c84345034",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = net.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c998137-0967-488f-a572-a5f5a6b86353",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_batch = torch.randn(BATCH_SIZE, *IMAGE_SHAPE, device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119ab631-fb3a-47a3-afc2-0e66260ebe7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspection only, so skip building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    x_batch_hat, batch_loss = net(x_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ccdec29-3952-460d-95b4-820b03aa4997",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_batch_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a847084a-a65d-4072-ae1e-ae5d85a1664a",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d9f6ab3-experiment",
   "metadata": {},
   "source": [
    "## Full training config with `+experiment=vqvae`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "764c8736-7d68-4261-a57d-face10ebbf42",
   "metadata": {},
   "outputs": [],
   "source": [
    "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n",
    "    cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqvae\"])\n",
    "\n",
    "print(OmegaConf.to_yaml(cfg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a9aa6b-6405-4ffe-b065-02340762476a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping = instantiate(cfg.mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "network = instantiate(cfg.network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "image = torch.rand(1, *IMAGE_SHAPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282",
   "metadata": {},
   "outputs": [],
   "source": [
    "network.encode(image)[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    image_hat, image_loss = network(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a9450d2-f45d-4823-adac-68a8ea05ed1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9983788-2dae-4375-a821-a64cd1c68edf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MSE between the network output and its input, plus the loss term the network returns.\n",
    "F.mse_loss(image_hat, image) + image_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29b128ca-80b7-481e-bb3c-44f109c7d292",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_hat.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ea07bc4-decode-note",
   "metadata": {},
   "source": [
    "### Decoding random latents\n",
    "\n",
    "**NOTE(review):** the cells below call `network.decode(latent, tokens)` and, later, `experiment_net(images, targets)` with target token sequences. The cells above unpack `network(image)` as an (output, loss) pair. These look like two different network APIs, so the cells were probably written against different `+experiment=` configs.\n",
    "\n",
    "**TODO:** confirm which experiment config each group belongs to and split the notebook accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "047ebc09-1c74-44a7-a314-1099f09722fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = torch.randint(0, NUM_TOKENS, (1, MAX_OUTPUT_LEN), device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87372dde-2b1a-432b-ab79-0b116124c724",
   "metadata": {},
   "outputs": [],
   "source": [
    "latent = torch.rand((1, LATENT_GRID[0] * LATENT_GRID[1], LATENT_DIM), device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf7ca9bf-cafa-4128-9db7-046c16933a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "network = network.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfceaa5f-9ad8-4d33-addb-c56e8da48356",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    decoded = network.decode(latent, tokens)\n",
    "decoded.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fb18cd5-data-model",
   "metadata": {},
   "source": [
    "## Datamodule and training model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a23893a9-a0da-4327-a617-dc0c2011e5e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn off OmegaConf struct mode on `cfg`, so keys not declared in the config may be added.\n",
    "OmegaConf.set_struct(cfg, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6fae1fa-492d-4648-80fd-1c0dac659b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule = instantiate(cfg.datamodule, mapping=mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514053ef-fcac-4f3c-a7c8-72c6927d6798",
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule.prepare_data()\n",
    "datamodule.setup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bad950b-a197-4c60-ad89-903124659a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(datamodule.train_dataloader())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7db05cbd-48b3-43fa-a99a-353126311879",
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc475fc-31f4-487e-88c8-b0f445131f5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = instantiate(cfg.criterion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5ff5b24-f804-402b-a8ab-f366443025ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `_recursive_=False` stops Hydra from instantiating nested configs,\n",
    "# so the optimizer / lr-scheduler configs are passed through as configs.\n",
    "model = instantiate(\n",
    "    cfg.model,\n",
    "    mapping=mapping,\n",
    "    network=network,\n",
    "    loss_fn=loss_fn,\n",
    "    optimizer_config=cfg.optimizer,\n",
    "    lr_scheduler_config=cfg.lr_scheduler,\n",
    "    _recursive_=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ac29de6-image-token",
   "metadata": {},
   "source": [
    "## Image + token forward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af2c8cfa-0b45-4681-b671-0f97ace62516",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh network instance from the experiment's `network` node (not the whole config).\n",
    "experiment_net = instantiate(cfg.network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40be59bc-db79-4af1-9df4-e280f7a56481",
   "metadata": {},
   "outputs": [],
   "source": [
    "images = torch.rand(4, *IMAGE_SHAPE, device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a8f10b-edf5-4a18-9747-f016db72c384",
   "metadata": {},
   "outputs": [],
   "source": [
    "targets = torch.randint(0, NUM_TOKENS, (4, MAX_OUTPUT_LEN), device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19423ef1-3d98-4af3-8748-fdd3bb817300",
   "metadata": {},
   "outputs": [],
   "source": [
    "targets.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0712ee7e-4f66-4fb1-bc91-d8a127eb7ac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_net = experiment_net.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "719154b4-47db-4c91-bae4-8c572c4a4536",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    output = experiment_net(images, targets)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31af8ee1-28d3-46b8-a847-6506d29bc45c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `device` must match where `experiment_net` lives, hence DEVICE.type rather than a fixed \"cpu\".\n",
    "# NOTE(review): `depth=` is not accepted by every package that installs as `torchsummary`,\n",
    "# and the summary presumably feeds random float tensors for the (MAX_OUTPUT_LEN,) token input - confirm both.\n",
    "summary(experiment_net, [IMAGE_SHAPE, (MAX_OUTPUT_LEN,)], device=DEVICE.type, depth=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}