{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", "from torchsummary import summary\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')\n", "\n", "from text_recognizer.networks.transformer.vit import ViT\n", "from text_recognizer.networks.transformer.transformer import Transformer\n", "from text_recognizer.networks.transformer.layers import Decoder\n", "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "scrolled": false }, "outputs": [], "source": [ "en = EfficientNet(\"b0\")" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", "├─Sequential: 1-1 [-1, 32, 111, 111] --\n", "| └─Conv2d: 2-1 [-1, 32, 111, 111] 288\n", "| └─BatchNorm2d: 2-2 [-1, 32, 111, 111] 64\n", "| └─Mish: 2-3 [-1, 32, 111, 111] --\n", "├─ModuleList: 1 [] --\n", "| └─MBConvBlock: 2-4 [-1, 16, 111, 111] --\n", "| | └─Sequential: 3-1 [-1, 32, 111, 111] 352\n", "| | └─Sequential: 3-2 [-1, 32, 111, 111] 552\n", "| | └─Sequential: 3-3 [-1, 16, 111, 111] 544\n", "| └─MBConvBlock: 2-5 [-1, 24, 55, 55] --\n", "| | └─Sequential: 
3-4 [-1, 96, 111, 111] 1,728\n", "| | └─Sequential: 3-5 [-1, 96, 55, 55] 1,056\n", "| | └─Sequential: 3-6 [-1, 96, 55, 55] 4,728\n", "| | └─Sequential: 3-7 [-1, 24, 55, 55] 2,352\n", "| └─MBConvBlock: 2-6 [-1, 24, 55, 55] --\n", "| | └─Sequential: 3-8 [-1, 144, 55, 55] 3,744\n", "| | └─Sequential: 3-9 [-1, 144, 55, 55] 1,584\n", "| | └─Sequential: 3-10 [-1, 144, 55, 55] 10,548\n", "| | └─Sequential: 3-11 [-1, 24, 55, 55] 3,504\n", "| └─MBConvBlock: 2-7 [-1, 40, 27, 27] --\n", "| | └─Sequential: 3-12 [-1, 144, 55, 55] 3,744\n", "| | └─Sequential: 3-13 [-1, 144, 27, 27] 3,888\n", "| | └─Sequential: 3-14 [-1, 144, 27, 27] 10,548\n", "| | └─Sequential: 3-15 [-1, 40, 27, 27] 5,840\n", "| └─MBConvBlock: 2-8 [-1, 40, 27, 27] --\n", "| | └─Sequential: 3-16 [-1, 240, 27, 27] 10,080\n", "| | └─Sequential: 3-17 [-1, 240, 27, 27] 6,480\n", "| | └─Sequential: 3-18 [-1, 240, 27, 27] 29,100\n", "| | └─Sequential: 3-19 [-1, 40, 27, 27] 9,680\n", "| └─MBConvBlock: 2-9 [-1, 80, 13, 13] --\n", "| | └─Sequential: 3-20 [-1, 240, 27, 27] 10,080\n", "| | └─Sequential: 3-21 [-1, 240, 13, 13] 2,640\n", "| | └─Sequential: 3-22 [-1, 240, 13, 13] 29,100\n", "| | └─Sequential: 3-23 [-1, 80, 13, 13] 19,360\n", "| └─MBConvBlock: 2-10 [-1, 80, 13, 13] --\n", "| | └─Sequential: 3-24 [-1, 480, 13, 13] 39,360\n", "| | └─Sequential: 3-25 [-1, 480, 13, 13] 5,280\n", "| | └─Sequential: 3-26 [-1, 480, 13, 13] 115,800\n", "| | └─Sequential: 3-27 [-1, 80, 13, 13] 38,560\n", "| └─MBConvBlock: 2-11 [-1, 80, 13, 13] --\n", "| | └─Sequential: 3-28 [-1, 480, 13, 13] 39,360\n", "| | └─Sequential: 3-29 [-1, 480, 13, 13] 5,280\n", "| | └─Sequential: 3-30 [-1, 480, 13, 13] 115,800\n", "| | └─Sequential: 3-31 [-1, 80, 13, 13] 38,560\n", "| └─MBConvBlock: 2-12 [-1, 112, 13, 13] --\n", "| | └─Sequential: 3-32 [-1, 480, 13, 13] 39,360\n", "| | └─Sequential: 3-33 [-1, 480, 13, 13] 12,960\n", "| | └─Sequential: 3-34 [-1, 480, 13, 13] 115,800\n", "| | └─Sequential: 3-35 [-1, 112, 13, 13] 53,984\n", "| └─MBConvBlock: 2-13 
[-1, 112, 13, 13] --\n", "| | └─Sequential: 3-36 [-1, 672, 13, 13] 76,608\n", "| | └─Sequential: 3-37 [-1, 672, 13, 13] 18,144\n", "| | └─Sequential: 3-38 [-1, 672, 13, 13] 226,632\n", "| | └─Sequential: 3-39 [-1, 112, 13, 13] 75,488\n", "| └─MBConvBlock: 2-14 [-1, 112, 13, 13] --\n", "| | └─Sequential: 3-40 [-1, 672, 13, 13] 76,608\n", "| | └─Sequential: 3-41 [-1, 672, 13, 13] 18,144\n", "| | └─Sequential: 3-42 [-1, 672, 13, 13] 226,632\n", "| | └─Sequential: 3-43 [-1, 112, 13, 13] 75,488\n", "| └─MBConvBlock: 2-15 [-1, 192, 6, 6] --\n", "| | └─Sequential: 3-44 [-1, 672, 13, 13] 76,608\n", "| | └─Sequential: 3-45 [-1, 672, 6, 6] 18,144\n", "| | └─Sequential: 3-46 [-1, 672, 6, 6] 226,632\n", "| | └─Sequential: 3-47 [-1, 192, 6, 6] 129,408\n", "| └─MBConvBlock: 2-16 [-1, 192, 6, 6] --\n", "| | └─Sequential: 3-48 [-1, 1152, 6, 6] 223,488\n", "| | └─Sequential: 3-49 [-1, 1152, 6, 6] 31,104\n", "| | └─Sequential: 3-50 [-1, 1152, 6, 6] 664,992\n", "| | └─Sequential: 3-51 [-1, 192, 6, 6] 221,568\n", "| └─MBConvBlock: 2-17 [-1, 192, 6, 6] --\n", "| | └─Sequential: 3-52 [-1, 1152, 6, 6] 223,488\n", "| | └─Sequential: 3-53 [-1, 1152, 6, 6] 31,104\n", "| | └─Sequential: 3-54 [-1, 1152, 6, 6] 664,992\n", "| | └─Sequential: 3-55 [-1, 192, 6, 6] 221,568\n", "| └─MBConvBlock: 2-18 [-1, 192, 6, 6] --\n", "| | └─Sequential: 3-56 [-1, 1152, 6, 6] 223,488\n", "| | └─Sequential: 3-57 [-1, 1152, 6, 6] 31,104\n", "| | └─Sequential: 3-58 [-1, 1152, 6, 6] 664,992\n", "| | └─Sequential: 3-59 [-1, 192, 6, 6] 221,568\n", "| └─MBConvBlock: 2-19 [-1, 320, 6, 6] --\n", "| | └─Sequential: 3-60 [-1, 1152, 6, 6] 223,488\n", "| | └─Sequential: 3-61 [-1, 1152, 6, 6] 12,672\n", "| | └─Sequential: 3-62 [-1, 1152, 6, 6] 664,992\n", "| | └─Sequential: 3-63 [-1, 320, 6, 6] 369,280\n", "├─Sequential: 1-2 [-1, 1280, 6, 6] --\n", "| └─Conv2d: 2-20 [-1, 1280, 6, 6] 409,600\n", "| └─BatchNorm2d: 2-21 [-1, 1280, 6, 6] 2,560\n", 
"==========================================================================================\n", "Total params: 7,142,272\n", "Trainable params: 7,142,272\n", "Non-trainable params: 0\n", "Total mult-adds (M): 657.05\n", "==========================================================================================\n", "Input size (MB): 0.19\n", "Forward/backward pass size (MB): 115.14\n", "Params size (MB): 27.25\n", "Estimated Total Size (MB): 142.58\n", "==========================================================================================\n" ] } ], "source": [ "summary(en, (1, 224, 224));" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2, 2)" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(2,) * 2" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. 
(Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:115.)\n", " return torch._C._cuda_getDeviceCount() > 0\n" ] }, { "data": { "text/plain": [ "False" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.cuda.is_available()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Device selection\n", "\n", "The stored output of the cell above is `False`, so the cells below move modules and tensors with `.to(device)` instead of calling `.cuda()` unconditionally." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Use the GPU when one is usable, otherwise stay on the CPU.\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Cross-attending decoder" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decoder.to(device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Wrap the decoder layers in the token-level Transformer (1000 tokens, sequences up to 690).\n", "transformer_decoder = Transformer(num_tokens=1000, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transformer_decoder.to(device)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## ViT wrapped around a Nystromer" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer = Nystromer(\n", "    dim=64,\n", "    depth=4,\n", "    num_heads=8,\n", "    num_landmarks=64,\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = ViT(\n", "    dim=64,\n", "    image_size=(576, 640),\n", "    patch_size=(32, 32),\n", "    transformer=efficient_transformer,\n", ").to(device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Random stand-in input: 4 items of shape (1, 576, 640), matching the ViT's image_size.\n", "t = torch.randn(4, 1, 576, 640).to(device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "o = v(t)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Random integer ids in [0, 90): 16 sequences of length 690 (the decoder's max_seq_len).\n", "caption = torch.randint(0, 90, (16, 690)).to(device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "o.shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption.shape" ] 
},
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Decoder forward pass on random context\n", "\n", "Random tensors stand in for the encoder output (`o`) and the target tokens (`caption`)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Pick the device here as well so this section does not rely on an unconditional .cuda().\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "o = torch.randn(16, 20 * 18, 128).to(device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption = torch.randint(0, 1000, (16, 200)).to(device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Presumably (batch, seq_len, num_tokens) = (16, 200, 1000) for these inputs - TODO confirm.\n", "transformer_decoder(caption, context=o).shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## EfficientNet encoder on a 576x640 input" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# EfficientNet is already imported near the top of the notebook.\n", "en = EfficientNet()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en.to(device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "summary(en, (1, 576, 640))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Nystromer with 256 landmarks" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(efficient_transformer)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# `efficient_transformer` is a Nystromer instance, so calling it invokes the model rather than\n", "# the constructor. Build a fresh model to change `num_landmarks`; the other arguments mirror\n", "# the earlier Nystromer(...) call.\n", "efficient_transformer = Nystromer(dim=64, depth=4, num_heads=8, num_landmarks=256)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## VQ-VAE built from its training config" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vqvae_config_path = \"../training/configs/vqvae.yaml\"" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vqvae_conf = OmegaConf.load(vqvae_config_path)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(vqvae_conf))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks import VQVAE" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae = VQVAE(**vqvae_conf.network.args)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 576, 640])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae.encoder(datum)[0].shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae(datum)[0].shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## CNN + Transformer built from its training config" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Random target ids in [0, 1000): 2 sequences of length 682.\n", "trg = torch.randint(0, 1000, [2, 682])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg.shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 224, 224])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# `t` is still the (4, 1, 576, 640) tensor created in the ViT section.\n", "en(t).shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cnn_transformer_config_path = \"../training/configs/cnn_transformer.yaml\"" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cnn_transformer_conf = OmegaConf.load(cnn_transformer_config_path)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(cnn_transformer_conf))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.cnn_transformer import CNNTransformer" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Named `cnn_transformer` so the tensor `t` used above is not overwritten by a model.\n", "cnn_transformer = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **cnn_transformer_conf.network.args)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE(review): `datum` is 224x224 here while the model is built with input_shape=(1, 576, 640) - confirm this is intended.\n", "cnn_transformer.encode(datum).shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg.shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cnn_transformer(datum, trg).shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Scratch: boolean mask and patch rearrangement" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "b, n = 16, 128\n", "# Separate name so the torch device chosen earlier is not replaced by a plain string.\n", "mask_device = \"cpu\"" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def make_mask():\n", "    \"\"\"Return a (b, n) all-True boolean tensor on `mask_device`.\"\"\"\n", "    return torch.ones((b, n), device=mask_device).bool()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "make_mask().shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.ones((b, n), device=mask_device).bool().shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 576, 640)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "576 // 32" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "640 // 32" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "18 * 20" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 144, 160)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Split the last two axes into patch_size x patch_size tiles and flatten each tile.\n", "patch_size = 16\n", "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p.shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 4 }