{ "cells": [ { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", "from torchsummary import summary\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')\n", "\n", "from text_recognizer.networks.transformer.vit import ViT\n", "from text_recognizer.networks.transformer.transformer import Transformer\n", "from text_recognizer.networks.transformer.layers import Decoder\n", "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import attr" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "@attr.s\n", "class B(nn.Module):\n", " input_dim = attr.ib()\n", " hidden = attr.ib()\n", " xx = attr.ib(init=False, default=\"hek\")\n", " \n", " def __attrs_post_init__(self):\n", " super().__init__()\n", " self.fc = nn.Linear(self.input_dim, self.hidden)\n", " self.xx = \"da\"\n", " \n", " def forward(self, x):\n", " return self.fc(x)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "def f(x):\n", " return 2\n", "\n", "@attr.s(auto_attribs=True)\n", "class T(B):\n", " \n", " h: Path = attr.ib(converter=Path)\n", " p: int = attr.ib(init=False, default=f(3))" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "__init__() missing 1 required positional argument: 'hidden'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"hej\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'hidden'" ] } ], "source": [ "t = T(input_dim=16, h=\"hej\")" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'da'" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.xx" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.p" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.input_dim" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "x = torch.rand(16, 16)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ 
{ "data": { "text/plain": [ "torch.Size([16, 16])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.shape" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "T(input_dim=16, hidden=24, h=PosixPath('hej'))" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.cuda()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "x = x.cuda()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 3.6047e-01, 1.0200e+00, 3.6786e-01, 1.6077e-01, 3.9281e-02,\n", " 3.2830e-01, 1.3433e-01, -9.0334e-02, -3.8712e-01, 8.1547e-01,\n", " -5.4483e-01, -9.7471e-01, 3.3706e-01, -9.5283e-01, -1.6271e-01,\n", " 3.8504e-01, -5.0106e-01, -4.8638e-01, 3.7033e-01, -4.9557e-01,\n", " 2.6555e-01, 5.1245e-01, 6.6751e-01, -2.6291e-01],\n", " [ 1.3811e-01, 7.4522e-01, 4.9935e-01, 3.3878e-01, 1.8501e-01,\n", " 2.2269e-02, -2.0328e-01, 1.4629e-01, -2.2957e-01, 4.1197e-01,\n", " -1.9555e-01, -4.7609e-01, 9.0206e-02, -8.8568e-01, -2.1618e-01,\n", " 2.8882e-01, -5.4335e-01, -6.6301e-01, 4.9990e-01, -4.0144e-01,\n", " 3.6403e-01, 5.3901e-01, 8.6665e-01, -7.8312e-02],\n", " [ 1.6493e-02, 4.6157e-01, 2.9500e-02, 2.4190e-01, 6.5753e-01,\n", " 4.3770e-02, -5.3773e-02, 1.8183e-01, -2.5983e-02, 4.1634e-01,\n", " -3.5218e-01, -5.6129e-01, 4.1452e-01, -1.2265e+00, -5.8544e-01,\n", " 3.6382e-01, -6.4090e-01, -5.8679e-01, 4.3489e-02, -1.1233e-01,\n", " 3.1175e-01, 4.2857e-01, 1.6501e-01, -2.4118e-01],\n", " [ 9.2361e-02, 6.0196e-01, 1.3081e-02, -8.1091e-02, 4.2342e-01,\n", " -8.8457e-02, -8.1851e-02, -1.1562e-01, -1.5049e-01, 4.9972e-01,\n", " -3.0432e-01, -7.8619e-01, 2.1060e-01, -1.0598e+00, -4.6542e-01,\n", " 4.2382e-01, -6.5671e-01, -4.8589e-01, 5.5977e-02, -2.9478e-02,\n", " 8.5718e-02, 4.7685e-01, 4.8351e-01, -2.8142e-01],\n", " [ 1.3377e-01, 5.4434e-01, 3.4505e-01, 1.1307e-01, 4.4057e-01,\n", " -7.6075e-03, 1.3841e-01, -1.1497e-01, -1.3177e-01, 8.0254e-01,\n", " -3.0627e-01, -6.8437e-01, 1.9035e-01, -1.0208e+00, -1.3259e-01,\n", " 5.3231e-01, -4.7814e-01, -5.1266e-01, 2.4646e-02, -3.0552e-01,\n", " 2.7398e-01, 5.8269e-01, 6.5481e-01, -4.2041e-01],\n", " [ 1.9604e-01, 4.0597e-01, 1.9071e-01, -2.5535e-01, 1.1915e-01,\n", " -6.7129e-02, 5.4386e-03, -8.2196e-02, -4.2803e-01, 7.0287e-01,\n", " -3.0026e-01, -7.6001e-01, -5.1471e-03, -7.0283e-01, -9.2978e-02,\n", " 1.2243e-01, -1.8398e-01, -4.7374e-01, 2.7978e-01, -3.6962e-01,\n", " 5.6046e-02, 4.1773e-01, 4.9894e-01, -3.1945e-01],\n", " [ 1.2657e-01, 3.3224e-01, 6.2830e-02, 1.5718e-01, 4.8844e-01,\n", " -1.1476e-01, -1.5044e-01, 2.5265e-02, -2.0351e-01, 5.5770e-01,\n", " -3.6036e-01, -7.4406e-01, 1.6962e-01, -9.6185e-01, -2.9334e-01,\n", " 2.2584e-01, -4.1169e-01, -5.2146e-01, 2.3314e-01, -1.3668e-01,\n", " -1.9598e-02, 3.8727e-01, 3.6892e-01, -3.3071e-01],\n", " [ 5.2178e-01, 6.9704e-01, 5.0093e-01, 1.1157e-01, 8.0012e-02,\n", " 3.6931e-01, -6.4927e-02, 1.1126e-01, -2.5117e-01, 5.3017e-01,\n", " -2.6488e-01, -8.4056e-01, 2.2374e-01, -6.6831e-01, -1.9402e-01,\n", " 7.4174e-02, -4.7763e-01, -2.6912e-01, 5.1009e-01, -5.4239e-01,\n", " 3.0123e-01, 3.7529e-01, 4.1625e-01, -2.0141e-01],\n", " [ 3.7968e-01, 4.9387e-01, 3.6786e-01, -1.3131e-01, 2.4445e-02,\n", " 2.2155e-01, -4.0087e-02, -1.4872e-01, -5.5030e-01, 6.8958e-01,\n", " -3.8156e-01, -7.5760e-01, 3.2085e-01, -6.4571e-01, 1.1268e-03,\n", " 3.4251e-02, -2.6440e-01, 
-2.6374e-01, 5.9787e-01, -4.6502e-01,\n", " 2.0074e-01, 4.5471e-01, 2.4238e-01, -4.3247e-01],\n", " [ 2.9364e-01, 4.8659e-01, 9.0845e-02, 1.6348e-01, 5.7636e-01,\n", " 4.5485e-01, -1.6781e-01, -1.4557e-01, -8.8814e-02, 6.6351e-01,\n", " -5.3669e-01, -8.2818e-01, 6.0474e-01, -9.4558e-01, -3.0133e-01,\n", " 3.0310e-01, -5.2493e-01, -2.5948e-01, 1.5857e-01, -4.2695e-01,\n", " 2.1311e-01, 4.6502e-01, 8.7946e-02, -5.5815e-01],\n", " [ 9.2208e-02, 2.9731e-01, 3.3849e-01, -5.1049e-02, 2.7834e-01,\n", " -1.1120e-01, 1.1835e-01, 1.3665e-01, -2.1291e-01, 3.5107e-01,\n", " -9.8108e-02, -5.0180e-01, 2.9894e-01, -7.7726e-01, -8.1317e-02,\n", " 3.5704e-01, -3.6759e-01, -2.2148e-01, 1.1019e-01, -1.4452e-02,\n", " 1.5092e-02, 3.3405e-01, 1.2765e-01, -4.0411e-01],\n", " [ 2.8927e-02, 4.4180e-01, 1.0994e-01, 5.6124e-01, 4.7174e-01,\n", " 1.9914e-01, -9.5047e-02, 3.1277e-02, -1.8656e-01, 5.0631e-01,\n", " -3.4353e-01, -5.7425e-01, 4.3409e-01, -8.3343e-01, -1.1627e-01,\n", " 3.1852e-02, -4.1274e-01, -2.6756e-01, 4.9652e-01, -2.6137e-01,\n", " 2.8559e-02, 3.0587e-01, 3.6717e-01, -4.4303e-01],\n", " [-1.0741e-01, 1.3539e-01, 1.5746e-01, 2.1208e-01, 6.3745e-01,\n", " -2.1864e-01, -1.8820e-01, 2.1184e-01, -3.6832e-02, 3.0890e-01,\n", " -2.4719e-03, -3.3573e-01, 1.8479e-01, -9.2119e-01, -2.3361e-01,\n", " 8.9827e-02, -5.4372e-01, -4.4935e-01, 3.2967e-01, -9.2807e-02,\n", " 9.9241e-02, 4.1705e-01, 2.4728e-01, -4.8119e-01],\n", " [ 2.8125e-01, 5.3276e-01, 5.0110e-02, 2.0471e-01, 5.7750e-01,\n", " 4.6670e-02, -2.1400e-01, 6.8794e-03, -6.8737e-02, 4.2138e-01,\n", " -3.1261e-01, -7.3709e-01, 4.2001e-01, -9.9757e-01, -4.8091e-01,\n", " 2.9960e-01, -6.2133e-01, -4.0566e-01, 3.2191e-01, -1.0219e-02,\n", " 1.2901e-01, 3.9601e-01, 1.6291e-01, -3.3871e-01],\n", " [ 2.9181e-01, 5.5400e-01, 3.0462e-01, 2.2431e-02, 2.8480e-01,\n", " 4.4624e-01, -2.8859e-01, -1.4629e-01, -4.3573e-02, 2.9742e-01,\n", " -1.0100e-01, -4.3070e-01, 4.6713e-01, -3.7132e-01, -8.6748e-02,\n", " 2.5666e-01, -3.5361e-01, -2.3917e-02, 3.0071e-01, -3.2420e-01,\n", " 1.3375e-01, 3.4475e-01, 3.0642e-01, -4.3496e-01],\n", " [-7.7723e-04, 2.3828e-01, 2.3124e-01, 4.1347e-01, 6.8455e-01,\n", " -9.8319e-03, 1.3403e-01, 1.8460e-02, -1.4025e-01, 5.9780e-01,\n", " -3.7015e-01, -5.7865e-01, 4.9211e-01, -1.1262e+00, -2.1693e-01,\n", " 3.2002e-01, -2.9313e-01, -3.1941e-01, 9.8446e-02, -6.2767e-02,\n", " -9.8636e-03, 3.5712e-01, 2.8833e-01, -5.3506e-01]], device='cuda:0',\n", " grad_fn=)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t(x)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('hej')" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.h" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.batch_size" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('hej')" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.h" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "../text_recognizer/__init__.py\n", "../text_recognizer/callbacks/__init__.py\n", "../text_recognizer/callbacks/wandb_callbacks.py\n", "../text_recognizer/data/image_utils.py\n", 
"../text_recognizer/data/emnist.py\n", "../text_recognizer/data/iam_lines.py\n", "../text_recognizer/data/download_utils.py\n", "../text_recognizer/data/mappings.py\n", "../text_recognizer/data/iam_preprocessor.py\n", "../text_recognizer/data/__init__.py\n", "../text_recognizer/data/make_wordpieces.py\n", "../text_recognizer/data/iam_paragraphs.py\n", "../text_recognizer/data/sentence_generator.py\n", "../text_recognizer/data/emnist_lines.py\n", "../text_recognizer/data/build_transitions.py\n", "../text_recognizer/data/base_dataset.py\n", "../text_recognizer/data/base_data_module.py\n", "../text_recognizer/data/iam.py\n", "../text_recognizer/data/iam_synthetic_paragraphs.py\n", "../text_recognizer/data/transforms.py\n", "../text_recognizer/data/iam_extended_paragraphs.py\n", "../text_recognizer/networks/__init__.py\n", "../text_recognizer/networks/util.py\n", "../text_recognizer/networks/cnn_tranformer.py\n", "../text_recognizer/networks/encoders/__init__.py\n", "../text_recognizer/networks/encoders/efficientnet/efficientnet.py\n", "../text_recognizer/networks/encoders/efficientnet/__init__.py\n", "../text_recognizer/networks/encoders/efficientnet/utils.py\n", "../text_recognizer/networks/encoders/efficientnet/mbconv.py\n", "../text_recognizer/networks/loss/__init__.py\n", "../text_recognizer/networks/loss/label_smoothing_loss.py\n", "../text_recognizer/networks/vqvae/__init__.py\n", "../text_recognizer/networks/vqvae/decoder.py\n", "../text_recognizer/networks/vqvae/vqvae.py\n", "../text_recognizer/networks/vqvae/vector_quantizer.py\n", "../text_recognizer/networks/vqvae/encoder.py\n", "../text_recognizer/networks/transformer/__init__.py\n", "../text_recognizer/networks/transformer/layers.py\n", "../text_recognizer/networks/transformer/residual.py\n", "../text_recognizer/networks/transformer/attention.py\n", "../text_recognizer/networks/transformer/transformer.py\n", "../text_recognizer/networks/transformer/vit.py\n", "../text_recognizer/networks/transformer/mlp.py\n", "../text_recognizer/networks/transformer/norm.py\n", "../text_recognizer/networks/transformer/positional_encodings/positional_encoding.py\n", "../text_recognizer/networks/transformer/positional_encodings/__init__.py\n", "../text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py\n", "../text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py\n", "../text_recognizer/networks/transformer/nystromer/__init__.py\n", "../text_recognizer/networks/transformer/nystromer/nystromer.py\n", "../text_recognizer/networks/transformer/nystromer/attention.py\n", "../text_recognizer/models/__init__.py\n", "../text_recognizer/models/base.py\n", "../text_recognizer/models/vqvae.py\n", "../text_recognizer/models/transformer.py\n", "../text_recognizer/models/dino.py\n", "../text_recognizer/models/metrics.py\n" ] } ], "source": [ "for f in Path(\"../text_recognizer\").glob(\"**/*.py\"):\n", " print(f)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Path(\"..\").glob(\"**/*.py\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [], "source": [ "en = EfficientNet(\"b0\")" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "scrolled": false }, "outputs": [], "source": [ "summary(en, (1, 224, 224));" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decoder.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transformer_decoder.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer = Nystromer(\n", " dim = 64,\n", " depth = 4,\n", " num_heads = 8,\n", " num_landmarks = 64\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = ViT(\n", " dim = 64,\n", " image_size = (576, 640),\n", " patch_size = (32, 32),\n", " transformer = efficient_transformer\n", ").cuda()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "t = torch.randn(4, 1, 576, 640).cuda()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "EfficientNet(\n", " (_conv_stem): Sequential(\n", " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n", " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (3): Mish(inplace=True)\n", " )\n", " (_blocks): ModuleList(\n", " (0): MBConvBlock(\n", " (_depthwise): Sequential(\n", " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n", " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)\n", " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (2): MBConvBlock(\n", " (_inverted_bottleneck): 
Sequential(\n", " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (3): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (4): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (5): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(240, 80, kernel_size=(1, 1), 
stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (6): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (7): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (8): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (9): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(672, 168, kernel_size=(1, 1), 
stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (10): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (11): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (12): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (13): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, 
bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (14): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (15): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " )\n", " (_conv_head): Sequential(\n", " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en.cuda()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 1280, 18, 20])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en(t).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "o = v(t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption = torch.randint(0, 90, (16, 690)).cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "o.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption.shape" ] }, { "cell_type": "code", 
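"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch of where the 20 * 18 context length below comes from: the EfficientNet\n", "# feature map is (B, C, 18, 20) for a 576x640 input, and flattening height x width gives a\n", "# 360-token sequence for the decoder. The projection of C down to dim 128 is assumed here.\n", "from einops import rearrange\n", "features = torch.randn(16, 128, 18, 20)  # stand-in for projected CNN features\n", "context = rearrange(features, 'b d h w -> b (h w) d')\n", "context.shape  # torch.Size([16, 360, 128])" ] }, { "cell_type": "code",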
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "o = torch.randn(16, 20 * 18, 128).cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption = torch.randint(0, 1000, (16, 200)).cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer = efficient_transformer(num_landmarks=256)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"../training/configs/vqvae.yaml\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks import VQVAE" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae = VQVAE(**conf.network.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae.encoder(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg = torch.randint(0, 1000, [2, 682])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 224, 224])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en(t).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"../training/configs/cnn_transformer.yaml\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.cnn_transformer import CNNTransformer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t.encode(datum).shape" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "trg.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t(datum, trg).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "b, n = 16, 128\n", "device = \"cpu\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = lambda: torch.ones((b, n), device=device).bool()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x().shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.ones((b, n), device=device).bool().shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "576 // 32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "640 // 32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "18 * 20" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 144, 160)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "patch_size=16\n", "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 4 }