summaryrefslogtreecommitdiff
path: root/notebooks/04-vqvae.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/04-vqvae.ipynb')
-rw-r--r--notebooks/04-vqvae.ipynb233
1 files changed, 233 insertions, 0 deletions
diff --git a/notebooks/04-vqvae.ipynb b/notebooks/04-vqvae.ipynb
new file mode 100644
index 0000000..1b31671
--- /dev/null
+++ b/notebooks/04-vqvae.ipynb
@@ -0,0 +1,233 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "136a80f5-10e1-40c4-973a-a7eb7939bb1f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n",
+ "import random\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import numpy as np\n",
+ "from omegaconf import OmegaConf\n",
+ "from hydra import compose, initialize\n",
+ "from omegaconf import OmegaConf\n",
+ "from hydra.utils import instantiate\n",
+ "from torchinfo import summary\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "1a0fb9ca-1886-4fd4-839f-dc111a450cfd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path = \"../training/conf/network/vqvae.yaml\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "0182a614-5781-44a6-b659-008e7c584fa7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "encoder:\n",
+ " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n",
+ " in_channels: 1\n",
+ " hidden_dim: 32\n",
+ " channels_multipliers:\n",
+ " - 1\n",
+ " - 2\n",
+ " - 4\n",
+ " dropout_rate: 0.0\n",
+ " activation: mish\n",
+ " use_norm: true\n",
+ " num_residuals: 4\n",
+ " residual_channels: 32\n",
+ "decoder:\n",
+ " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n",
+ " out_channels: 1\n",
+ " hidden_dim: 32\n",
+ " channels_multipliers:\n",
+ " - 4\n",
+ " - 2\n",
+ " - 1\n",
+ " dropout_rate: 0.0\n",
+ " activation: mish\n",
+ " use_norm: true\n",
+ " num_residuals: 4\n",
+ " residual_channels: 32\n",
+ "_target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n",
+ "hidden_dim: 128\n",
+ "embedding_dim: 32\n",
+ "num_embeddings: 8192\n",
+ "decay: 0.99\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'vqvae': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information\n",
+ " warnings.warn(msg, UserWarning)\n"
+ ]
+ }
+ ],
+ "source": [
+ "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"vqvae\")\n",
+ " print(OmegaConf.to_yaml(cfg))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "a500f94c-7dae-477e-a3fb-2a2d62ee7b72",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = instantiate(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "7f3b3559-5e23-485e-bf57-9405568a1fbf",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "====================================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "====================================================================================================\n",
+ "VQVAE -- --\n",
+ "├─Encoder: 1-1 [1, 128, 72, 80] --\n",
+ "│ └─Sequential: 2-1 [1, 128, 72, 80] --\n",
+ "│ │ └─Conv2d: 3-1 [1, 32, 576, 640] 320\n",
+ "│ │ └─Normalize: 3-2 [1, 32, 576, 640] 64\n",
+ "│ │ └─Mish: 3-3 [1, 32, 576, 640] --\n",
+ "│ │ └─Mish: 3-4 [1, 32, 576, 640] --\n",
+ "│ │ └─Mish: 3-5 [1, 32, 576, 640] --\n",
+ "│ │ └─Conv2d: 3-6 [1, 32, 288, 320] 16,416\n",
+ "│ │ └─Normalize: 3-7 [1, 32, 288, 320] 64\n",
+ "│ │ └─Mish: 3-8 [1, 32, 288, 320] --\n",
+ "│ │ └─Mish: 3-9 [1, 32, 288, 320] --\n",
+ "│ │ └─Mish: 3-10 [1, 32, 288, 320] --\n",
+ "│ │ └─Conv2d: 3-11 [1, 64, 144, 160] 32,832\n",
+ "│ │ └─Normalize: 3-12 [1, 64, 144, 160] 128\n",
+ "│ │ └─Mish: 3-13 [1, 64, 144, 160] --\n",
+ "│ │ └─Mish: 3-14 [1, 64, 144, 160] --\n",
+ "│ │ └─Mish: 3-15 [1, 64, 144, 160] --\n",
+ "│ │ └─Conv2d: 3-16 [1, 128, 72, 80] 131,200\n",
+ "│ │ └─Residual: 3-17 [1, 128, 72, 80] 41,280\n",
+ "│ │ └─Residual: 3-18 [1, 128, 72, 80] 41,280\n",
+ "│ │ └─Residual: 3-19 [1, 128, 72, 80] 41,280\n",
+ "│ │ └─Residual: 3-20 [1, 128, 72, 80] 41,280\n",
+ "├─Conv2d: 1-2 [1, 32, 72, 80] 4,128\n",
+ "├─VectorQuantizer: 1-3 [1, 32, 72, 80] --\n",
+ "├─Conv2d: 1-4 [1, 128, 72, 80] 4,224\n",
+ "├─Decoder: 1-5 [1, 1, 576, 640] --\n",
+ "│ └─Sequential: 2-2 [1, 1, 576, 640] --\n",
+ "│ │ └─Residual: 3-21 [1, 128, 72, 80] 41,280\n",
+ "│ │ └─Residual: 3-22 [1, 128, 72, 80] 41,280\n",
+ "│ │ └─Residual: 3-23 [1, 128, 72, 80] 41,280\n",
+ "│ │ └─Residual: 3-24 [1, 128, 72, 80] 41,280\n",
+ "│ │ └─Normalize: 3-25 [1, 128, 72, 80] 256\n",
+ "│ │ └─Mish: 3-26 [1, 128, 72, 80] --\n",
+ "│ │ └─Mish: 3-27 [1, 128, 72, 80] --\n",
+ "│ │ └─Mish: 3-28 [1, 128, 72, 80] --\n",
+ "│ │ └─ConvTranspose2d: 3-29 [1, 64, 144, 160] 131,136\n",
+ "│ │ └─Normalize: 3-30 [1, 64, 144, 160] 128\n",
+ "│ │ └─Mish: 3-31 [1, 64, 144, 160] --\n",
+ "│ │ └─Mish: 3-32 [1, 64, 144, 160] --\n",
+ "│ │ └─Mish: 3-33 [1, 64, 144, 160] --\n",
+ "│ │ └─ConvTranspose2d: 3-34 [1, 32, 288, 320] 32,800\n",
+ "│ │ └─Normalize: 3-35 [1, 32, 288, 320] 64\n",
+ "│ │ └─Mish: 3-36 [1, 32, 288, 320] --\n",
+ "│ │ └─Mish: 3-37 [1, 32, 288, 320] --\n",
+ "│ │ └─Mish: 3-38 [1, 32, 288, 320] --\n",
+ "│ │ └─ConvTranspose2d: 3-39 [1, 32, 576, 640] 16,416\n",
+ "│ │ └─Normalize: 3-40 [1, 32, 576, 640] 64\n",
+ "│ │ └─Conv2d: 3-41 [1, 1, 576, 640] 289\n",
+ "====================================================================================================\n",
+ "Total params: 700,769\n",
+ "Trainable params: 700,769\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 17.28\n",
+ "====================================================================================================\n",
+ "Input size (MB): 1.47\n",
+ "Forward/backward pass size (MB): 659.13\n",
+ "Params size (MB): 2.80\n",
+ "Estimated Total Size (MB): 663.41\n",
+ "===================================================================================================="
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "summary(net, (1, 1, 576, 640), device=\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9f880b03-d641-4640-acd3-aa5666ca5184",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}