{ "cells": [ { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", "from torchsummary import summary\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "path = \"../training/configs/vqvae.yaml\"" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "seed: 4711\n", "network:\n", " desc: Configuration of the PyTorch neural network.\n", " type: VQVAE\n", " args:\n", " in_channels: 1\n", " channels:\n", " - 32\n", " - 64\n", " - 96\n", " - 96\n", " - 128\n", " kernel_sizes:\n", " - 4\n", " - 4\n", " - 4\n", " - 4\n", " - 4\n", " strides:\n", " - 2\n", " - 2\n", " - 2\n", " - 2\n", " - 2\n", " num_residual_layers: 2\n", " embedding_dim: 128\n", " num_embeddings: 1024\n", " upsampling: null\n", " beta: 0.25\n", " activation: leaky_relu\n", " dropout_rate: 0.1\n", "model:\n", " desc: Configuration of the PyTorch Lightning model.\n", " type: LitVQVAEModel\n", " args:\n", " optimizer:\n", " type: MADGRAD\n", " args:\n", " lr: 0.001\n", " momentum: 0.9\n", " weight_decay: 0\n", " eps: 1.0e-06\n", " lr_scheduler:\n", " type: OneCycleLR\n", " args:\n", " interval: step\n", " max_lr: 0.001\n", " three_phase: true\n", " epochs: 1024\n", " steps_per_epoch: 317\n", " criterion:\n", " type: MSELoss\n", " args:\n", " reduction: mean\n", " monitor: val_loss\n", " mapping: sentence_piece\n", "data:\n", " desc: Configuration of the training/test data.\n", " type: IAMExtendedParagraphs\n", " args:\n", " batch_size: 64\n", " num_workers: 12\n", " train_fraction: 0.8\n", " augment: true\n", "callbacks:\n", "- type: ModelCheckpoint\n", " args:\n", " monitor: val_loss\n", " mode: min\n", " save_last: true\n", "- type: LearningRateMonitor\n", " args:\n", " logging_interval: step\n", "trainer:\n", " desc: Configuration of the PyTorch Lightning Trainer.\n", " args:\n", " stochastic_weight_avg: false\n", " auto_scale_batch_size: binsearch\n", " gradient_clip_val: 0\n", " fast_dev_run: false\n", " gpus: 1\n", " precision: 16\n", " max_epochs: 1024\n", " terminate_on_nan: true\n", " weights_summary: full\n", "load_checkpoint: null\n", "\n" ] } ], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks import VQVAE" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "vae = VQVAE(**conf.network.args)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "VQVAE(\n", " (encoder): Encoder(\n", " (encoder): Sequential(\n", " (0): Sequential(\n", " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (1): Dropout(p=0.1, inplace=False)\n", " (2): 
Sequential(\n", " (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Sequential(\n", " (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (5): Dropout(p=0.1, inplace=False)\n", " (6): Sequential(\n", " (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (7): Dropout(p=0.1, inplace=False)\n", " (8): Sequential(\n", " (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (9): Dropout(p=0.1, inplace=False)\n", " (10): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (11): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (12): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (vector_quantizer): VectorQuantizer(\n", " (embedding): Embedding(1024, 128)\n", " )\n", " )\n", " (decoder): Decoder(\n", " (res_block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", " (1): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (upsampling_block): Sequential(\n", " (0): Sequential(\n", " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (1): Dropout(p=0.1, inplace=False)\n", " (2): Sequential(\n", " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Sequential(\n", " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (5): Dropout(p=0.1, inplace=False)\n", " (6): Sequential(\n", " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (7): Dropout(p=0.1, inplace=False)\n", " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (9): Tanh()\n", " )\n", " (decoder): Sequential(\n", " (0): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", " (1): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (1): Sequential(\n", " (0): Sequential(\n", " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (1): Dropout(p=0.1, inplace=False)\n", " (2): Sequential(\n", " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Sequential(\n", " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (5): Dropout(p=0.1, inplace=False)\n", " (6): Sequential(\n", " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (7): Dropout(p=0.1, inplace=False)\n", " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (9): Tanh()\n", " )\n", " )\n", " )\n", ")" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vae" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "x = proj(datum)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 32, 36, 40])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.shape" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "xx = x.flatten(2)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 32, 1440])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xx.shape" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "xxx = xx.transpose(1,2)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1440, 32])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xxx.shape" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "xxxx = rearrange(x, \"b c h w -> b ( h w ) c\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1440, 32])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xxxx.shape" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ " B, N, C = x.shape\n", " H, W = size\n", " assert N == 1 + H * W\n", "\n", " # Extract CLS token and image tokens.\n", " cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].\n", " \n", " # Depthwise convolution.\n", " feat = img_tokens.transpose(1, 2).view(B, C, H, W)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 32, 36, 40])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xxx.transpose(1, 2).view(2, 32, 36, 40).shape" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "72.0" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "576 / 8" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "80.0" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "640 / 8" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 576, 640])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datum.shape" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 128, 18, 20])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vae.encoder(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 576, 640])" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vae(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.1" } }, "nbformat": 4, "nbformat_minor": 4 }