diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-04 23:24:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-04 23:24:00 +0200 |
commit | 4defc734b681071e19dd86404abd416d24330b9a (patch) | |
tree | 2447a7bc3fada64d1b45ac73346f816f9e90849c /notebooks/00-scratch-pad.ipynb | |
parent | 53450493e0a13d835fd1d2457c49a9d60bee0e18 (diff) |
Bug fix
Diffstat (limited to 'notebooks/00-scratch-pad.ipynb')
-rw-r--r-- | notebooks/00-scratch-pad.ipynb | 598 |
1 files changed, 598 insertions, 0 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb new file mode 100644 index 0000000..d50fd59 --- /dev/null +++ b/notebooks/00-scratch-pad.ipynb @@ -0,0 +1,598 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "import torch.nn.functional as F\n", + "import torch\n", + "from torch import nn\n", + "from torchsummary import summary\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ModuleList(\n", + " (0): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (1): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (2): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (3): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (4): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (5): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (6): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (7): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (8): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (9): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "path = \"../training/configs/vqvae.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "conf = OmegaConf.load(path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(OmegaConf.to_yaml(conf))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks import VQVAE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae = VQVAE(**conf.network.args)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datum = torch.randn([2, 1, 576, 640])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae.encoder(datum)[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae(datum)[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.backbones.efficientnet import EfficientNet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "en = EfficientNet()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "datum = torch.randn([2, 1, 576, 640])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "trg = torch.randint(0, 1000, [2, 682])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trg.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datum = torch.randn([2, 1, 224, 224])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "en(datum).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "path = \"../training/configs/cnn_transformer.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "conf = OmegaConf.load(path)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 4711\n", + "network:\n", + " desc: Configuration of the PyTorch neural network.\n", + " type: CNNTransformer\n", + " args:\n", + " encoder:\n", + " type: EfficientNet\n", + " args: null\n", + " num_decoder_layers: 4\n", + " hidden_dim: 256\n", + " num_heads: 4\n", + " expansion_dim: 1024\n", + " dropout_rate: 0.1\n", + " transformer_activation: glu\n", + "model:\n", + " desc: Configuration of the PyTorch Lightning model.\n", + " type: LitTransformerModel\n", + " args:\n", + " optimizer:\n", + " type: MADGRAD\n", + " args:\n", + " lr: 0.001\n", + " momentum: 0.9\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + " lr_scheduler:\n", + " type: OneCycleLR\n", + " args:\n", + " interval: step\n", + " max_lr: 0.001\n", + " three_phase: true\n", + " epochs: 512\n", + " steps_per_epoch: 1246\n", + " criterion:\n", + " type: CrossEntropyLoss\n", + " args:\n", + " weight: None\n", + " ignore_index: -100\n", + " reduction: mean\n", + " monitor: val_loss\n", + " mapping: sentence_piece\n", + "data:\n", + " desc: Configuration of the training/test data.\n", + " type: IAMExtendedParagraphs\n", + " args:\n", + " batch_size: 16\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " augment: true\n", + "callbacks:\n", + "- type: ModelCheckpoint\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " save_last: true\n", + "- type: StochasticWeightAveraging\n", + " args:\n", + " swa_epoch_start: 0.8\n", + " swa_lrs: 0.05\n", + " annealing_epochs: 10\n", + " annealing_strategy: cos\n", + " device: null\n", + "- type: LearningRateMonitor\n", + " args:\n", + " logging_interval: step\n", + "- type: EarlyStopping\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " patience: 10\n", + "trainer:\n", + " desc: Configuration of the PyTorch Lightning Trainer.\n", + " args:\n", + " stochastic_weight_avg: true\n", + " auto_scale_batch_size: binsearch\n", + " gradient_clip_val: 0\n", + " fast_dev_run: false\n", + " gpus: 1\n", + " precision: 16\n", + " max_epochs: 512\n", + " terminate_on_nan: true\n", + " weights_summary: true\n", + "load_checkpoint: null\n", + "\n" + ] + } + ], + "source": [ + "print(OmegaConf.to_yaml(conf))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.cnn_transformer import CNNTransformer" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t.encode(datum).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trg.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 682, 1004])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t(datum, trg).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "b, n = 16, 128\n", + "device = \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "x = lambda: torch.ones((b, n), device=device).bool()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 128])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 128])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.ones((b, n), device=device).bool().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 1, 576, 640)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "144" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "576 // 4" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "160" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "640 // 4" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 1, 144, 160)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import rearrange" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "patch_size=16\n", + "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1440, 256])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} |