summaryrefslogtreecommitdiff
path: root/notebooks/00-scratch-pad.ipynb
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-04 23:24:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-04 23:24:00 +0200
commit4defc734b681071e19dd86404abd416d24330b9a (patch)
tree2447a7bc3fada64d1b45ac73346f816f9e90849c /notebooks/00-scratch-pad.ipynb
parent53450493e0a13d835fd1d2457c49a9d60bee0e18 (diff)
Bug fix
Diffstat (limited to 'notebooks/00-scratch-pad.ipynb')
-rw-r--r--notebooks/00-scratch-pad.ipynb598
1 files changed, 598 insertions, 0 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
new file mode 100644
index 0000000..d50fd59
--- /dev/null
+++ b/notebooks/00-scratch-pad.ipynb
@@ -0,0 +1,598 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch.nn.functional as F\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torchsummary import summary\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ModuleList(\n",
+ " (0): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (1): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (2): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (3): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (4): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (5): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (6): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (7): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (8): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ " (9): ModuleList(\n",
+ " (0): Linear(in_features=10, out_features=10, bias=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from omegaconf import OmegaConf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path = \"../training/configs/vqvae.yaml\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "conf = OmegaConf.load(path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(OmegaConf.to_yaml(conf))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import VQVAE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vae = VQVAE(**conf.network.args)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vae"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "datum = torch.randn([2, 1, 576, 640])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vae.encoder(datum)[0].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vae(datum)[0].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.backbones.efficientnet import EfficientNet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "en = EfficientNet()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "datum = torch.randn([2, 1, 576, 640])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trg = torch.randint(0, 1000, [2, 682])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trg.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "datum = torch.randn([2, 1, 224, 224])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "en(datum).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path = \"../training/configs/cnn_transformer.yaml\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "conf = OmegaConf.load(path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "seed: 4711\n",
+ "network:\n",
+ " desc: Configuration of the PyTorch neural network.\n",
+ " type: CNNTransformer\n",
+ " args:\n",
+ " encoder:\n",
+ " type: EfficientNet\n",
+ " args: null\n",
+ " num_decoder_layers: 4\n",
+ " hidden_dim: 256\n",
+ " num_heads: 4\n",
+ " expansion_dim: 1024\n",
+ " dropout_rate: 0.1\n",
+ " transformer_activation: glu\n",
+ "model:\n",
+ " desc: Configuration of the PyTorch Lightning model.\n",
+ " type: LitTransformerModel\n",
+ " args:\n",
+ " optimizer:\n",
+ " type: MADGRAD\n",
+ " args:\n",
+ " lr: 0.001\n",
+ " momentum: 0.9\n",
+ " weight_decay: 0\n",
+ " eps: 1.0e-06\n",
+ " lr_scheduler:\n",
+ " type: OneCycleLR\n",
+ " args:\n",
+ " interval: step\n",
+ " max_lr: 0.001\n",
+ " three_phase: true\n",
+ " epochs: 512\n",
+ " steps_per_epoch: 1246\n",
+ " criterion:\n",
+ " type: CrossEntropyLoss\n",
+ " args:\n",
+ " weight: None\n",
+ " ignore_index: -100\n",
+ " reduction: mean\n",
+ " monitor: val_loss\n",
+ " mapping: sentence_piece\n",
+ "data:\n",
+ " desc: Configuration of the training/test data.\n",
+ " type: IAMExtendedParagraphs\n",
+ " args:\n",
+ " batch_size: 16\n",
+ " num_workers: 12\n",
+ " train_fraction: 0.8\n",
+ " augment: true\n",
+ "callbacks:\n",
+ "- type: ModelCheckpoint\n",
+ " args:\n",
+ " monitor: val_loss\n",
+ " mode: min\n",
+ " save_last: true\n",
+ "- type: StochasticWeightAveraging\n",
+ " args:\n",
+ " swa_epoch_start: 0.8\n",
+ " swa_lrs: 0.05\n",
+ " annealing_epochs: 10\n",
+ " annealing_strategy: cos\n",
+ " device: null\n",
+ "- type: LearningRateMonitor\n",
+ " args:\n",
+ " logging_interval: step\n",
+ "- type: EarlyStopping\n",
+ " args:\n",
+ " monitor: val_loss\n",
+ " mode: min\n",
+ " patience: 10\n",
+ "trainer:\n",
+ " desc: Configuration of the PyTorch Lightning Trainer.\n",
+ " args:\n",
+ " stochastic_weight_avg: true\n",
+ " auto_scale_batch_size: binsearch\n",
+ " gradient_clip_val: 0\n",
+ " fast_dev_run: false\n",
+ " gpus: 1\n",
+ " precision: 16\n",
+ " max_epochs: 512\n",
+ " terminate_on_nan: true\n",
+ " weights_summary: true\n",
+ "load_checkpoint: null\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(OmegaConf.to_yaml(conf))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.cnn_transformer import CNNTransformer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t.encode(datum).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trg.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 682, 1004])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t(datum, trg).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b, n = 16, 128\n",
+ "device = \"cpu\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = lambda: torch.ones((b, n), device=device).bool()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 128])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x().shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 128])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.ones((b, n), device=device).bool().shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = torch.randn(1, 1, 576, 640)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "144"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "576 // 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "160"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "640 // 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = torch.randn(1, 1, 144, 160)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import rearrange"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "patch_size=16\n",
+ "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 1440, 256])"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "p.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}