{ "cells": [ { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", "from torchsummary import summary\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')\n", "\n", "from text_recognizer.networks.transformer.vit import ViT\n", "from text_recognizer.networks.transformer.transformer import Transformer\n", "from text_recognizer.networks.transformer.layers import Decoder\n", "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [], "source": [ "en = EfficientNet(\"b0\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def generate_square_subsequent_mask(size: int) -> torch.Tensor:\n", "    \"\"\"Return an additive causal attention mask of shape (size, size).\n", "\n", "    Entries on or below the diagonal are 0.0 (the query may attend to that key);\n", "    entries above the diagonal are -inf (future keys are blocked).\n", "    The result is always float32.\n", "    \"\"\"\n", "    future = ~torch.tril(torch.ones(size, size)).bool()\n", "    return torch.zeros(size, size, dtype=torch.float32).masked_fill(future, float(\"-inf\"))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., -inf, -inf, -inf],\n", " [0., 0., -inf, -inf],\n", " [0., 0., 0., -inf],\n", " [0., 0., 0., 0.]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "generate_square_subsequent_mask(4)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from torch import Tensor" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "tgt = torch.randint(0, 4, (1, 4))\n", "tgt_mask = torch.ones_like(tgt).bool()" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True, True]])" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tgt_mask" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:\n", "    \"\"\"Return the combined padding + causal (look-ahead) mask for decoder targets.\n", "\n", "    Args:\n", "        trg: Target tokens of shape (batch, seq_len).\n", "        pad_index: Value that marks padding positions in ``trg``.\n", "\n", "    Returns:\n", "        Bool tensor of shape (batch, 1, seq_len, seq_len). Entry ``[b, 0, i, j]`` is\n", "        True iff ``j <= i`` and ``trg[b, j] != pad_index``, i.e. query position ``i``\n", "        may attend to key position ``j``.\n", "    \"\"\"\n", "    # (batch, 1, 1, seq_len): hides padded *key* positions from every query row.\n", "    trg_pad_mask = (trg != pad_index)[:, None, None]\n", "    trg_len = trg.shape[1]\n", "    # (seq_len, seq_len) lower-triangular: position i only sees positions <= i.\n", "    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), dtype=torch.bool, device=trg.device))\n", "    # Broadcasts to (batch, 1, seq_len, seq_len).\n", "    return trg_pad_mask & trg_sub_mask" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "t = torch.randint(0, 6, (0, 4))" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "t = torch.Tensor([[0, 0, 0, 3, 3, 3]])" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "tt = t != 3" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ True, True, True, False, False, False]])" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tt" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "# Batch of two identical padded sequences, rebuilt from the literal so that\n", "# re-running this cell does not keep doubling the batch (as torch.cat((t, t)) did).\n", "t = torch.Tensor([[0, 0, 0, 3, 3, 3]]).repeat(2, 1)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 6])" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "t.shape" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[[ True, False, False, False, False, False],\n", " [ True, True, False, False, False, False],\n", " [ True, True, True, False, False, False],\n", " [ True, True, True, False, False, False],\n", " [ True, True, True, False, False, False],\n", " [ True, True, True, False, False, False]]],\n", "\n", "\n", " [[[ True, False, False, False, False, False],\n", " [ True, True, False, False, False, False],\n", " [ True, True, True, False, False, False],\n", " [ True, True, True, False, False, False],\n", " [ True, True, True, False, False, False],\n", " [ True, True, True, False, False, False]]]])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target_padding_mask(t, 3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Expected torch.Size([2, 1, 6, 6]), i.e. (batch, 1, seq_len, seq_len).\n", "target_padding_mask(t, pad_index=3).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# `en` is still on the CPU here (en.cuda() runs further down), while\n", "# torchsummary's `summary` defaults to device=\"cuda\"; pass the device explicitly.\n", "summary(en, (1, 224, 224), device=\"cpu\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decoder.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transformer_decoder.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"efficient_transformer = Nystromer(\n", "    dim=64,\n", "    depth=4,\n", "    num_heads=8,\n", "    num_landmarks=64,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = ViT(\n", "    dim=64,\n", "    image_size=(576, 640),\n", "    patch_size=(32, 32),\n", "    transformer=efficient_transformer,\n", ").cuda()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "t = torch.randn(4, 1, 576, 640).cuda()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "EfficientNet(\n", " (_conv_stem): Sequential(\n", " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n", " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (3): Mish(inplace=True)\n", " )\n", " (_blocks): ModuleList(\n", " (0): MBConvBlock(\n", " (_depthwise): Sequential(\n", " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n", " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, 
bias=False)\n", " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (2): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (3): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n", " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): 
Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (4): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (5): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n", " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", " 
(1): Mish(inplace=True)\n", " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (6): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (7): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): 
Sequential(\n", " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (8): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n", " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (9): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(112, 
eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (10): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (11): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n", " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (12): MBConvBlock(\n", " 
(_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (13): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (14): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), 
bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (15): MBConvBlock(\n", " (_inverted_bottleneck): Sequential(\n", " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_depthwise): Sequential(\n", " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n", " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " (2): Mish(inplace=True)\n", " )\n", " (_squeeze_excite): Sequential(\n", " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", " (1): Mish(inplace=True)\n", " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (_pointwise): Sequential(\n", " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", " )\n", " )\n", " (_conv_head): Sequential(\n", " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", " )\n", ")" ] }, 
"execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en.cuda()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 1280, 18, 20])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en(t).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "o = v(t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption = torch.randint(0, 90, (16, 690)).cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "o.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "o = torch.randn(16, 20 * 18, 128).cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "caption = torch.randint(0, 1000, (16, 200)).cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE(review): presumably (batch, seq_len, num_tokens) = (16, 200, 1003) here -- TODO confirm.\n", "transformer_decoder(caption, context=o).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Re-instantiate with more landmarks. Calling the existing instance would run its\n", "# forward pass (and rebind the name to the output) instead of re-configuring it.\n", "efficient_transformer_256 = Nystromer(dim=64, depth=4, num_heads=8, num_landmarks=256)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer_256" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"../training/configs/vqvae.yaml\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks import VQVAE" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae = VQVAE(**conf.network.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn(2, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae.encoder(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn(2, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg = torch.randint(0, 1000, (2, 682))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn(2, 1, 224, 224)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en(t).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"../training/configs/cnn_transformer.yaml\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from 
text_recognizer.networks.cnn_transformer import CNNTransformer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Named cnn_transformer (not t) so the image tensor `t` used by earlier cells is not shadowed.\n", "cnn_transformer = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE(review): datum was last set to shape (2, 1, 224, 224) while input_shape is (1, 576, 640) -- confirm intended.\n", "cnn_transformer.encode(datum).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cnn_transformer(datum, trg).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "b, n = 16, 128\n", "device = \"cpu\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def ones_mask() -> torch.Tensor:\n", "    \"\"\"Return an all-True bool mask of shape (b, n) on `device` (globals read at call time).\"\"\"\n", "    return torch.ones((b, n), device=device).bool()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ones_mask().shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.ones((b, n), device=device).bool().shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "576 // 32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "640 // 32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "18 * 20" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 144, 160)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "patch_size = 16\n", "p = rearrange(\n", "    x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 4 }