From 7b8705f382b1642cf171cf7fcd01295104b9deef Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 30 Sep 2021 23:10:59 +0200 Subject: Delete legacy notebooks --- notebooks/00-scratch-pad.ipynb | 1669 ---------------------------------------- 1 file changed, 1669 deletions(-) delete mode 100644 notebooks/00-scratch-pad.ipynb (limited to 'notebooks/00-scratch-pad.ipynb') diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb deleted file mode 100644 index d0f4215..0000000 --- a/notebooks/00-scratch-pad.ipynb +++ /dev/null @@ -1,1669 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from PIL import Image\n", - "import torch.nn.functional as F\n", - "import torch\n", - "from torch import nn\n", - "from torchsummary import summary\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')\n", - "\n", - "from text_recognizer.networks.transformer.vit import ViT\n", - "from text_recognizer.networks.transformer.transformer import Transformer\n", - "from text_recognizer.networks.transformer.layers import Decoder" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randint(0, 5, (4, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "36" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "576 // 16" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "40" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "640 // 16" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1440" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "36 * 40" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0, 1, 2, 1],\n", - " [1, 2, 3, 3],\n", - " [2, 2, 3, 3],\n", - " [4, 0, 2, 4]])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randint(0, 5, (1, 4, 4, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[2, 3, 3, 3],\n", - " [3, 4, 4, 2],\n", - " [2, 3, 0, 0],\n", - " [4, 3, 4, 0]],\n", - "\n", - " [[3, 0, 3, 0],\n", - " [1, 4, 1, 3],\n", - " [2, 3, 3, 3],\n", - " [2, 3, 3, 1]],\n", - "\n", - " [[1, 1, 0, 3],\n", - " [1, 3, 0, 4],\n", - " [3, 1, 4, 2],\n", - " [3, 1, 4, 3]],\n", - "\n", - " [[3, 2, 3, 4],\n", - " [3, 2, 3, 3],\n", - " [0, 2, 2, 3],\n", - " [4, 0, 3, 4]]]])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 4, 16])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.flatten(start_dim=2).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[2, 3, 3, 3, 3, 4, 4, 2, 2, 3, 0, 0, 4, 3, 4, 0],\n", - " [3, 0, 3, 0, 1, 4, 1, 3, 2, 3, 3, 3, 2, 3, 3, 1],\n", - " [1, 1, 0, 3, 1, 3, 0, 4, 3, 1, 4, 2, 3, 1, 4, 3],\n", - " [3, 2, 3, 4, 3, 2, 3, 3, 0, 2, 2, 3, 4, 0, 3, 4]]])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.flatten(start_dim=2)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "__init__() got an unexpected keyword argument 'dim'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_6532/3641656095.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mflatten\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFlatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'dim'" - ] - } - ], - "source": [ - "flatten = nn.Flatten(stdim=2)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.cuda.is_available()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "loss = nn.CrossEntropyLoss()" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [], - "source": [ - "o = torch.randn((4, 5, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randint(0, 5, (4, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4, 5, 4])" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "o.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4, 4])" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0, 1, 3, 2],\n", - " [1, 4, 4, 4],\n", - " [1, 4, 2, 1],\n", - " [2, 0, 4, 4]])" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[ 0.0647, -1.3831, 0.0266, 0.8528],\n", - " [ 1.4976, 0.4153, 1.0353, 0.0154],\n", - " [ 1.4562, -0.3568, 0.3599, -0.6222],\n", - " [ 0.2773, 0.4563, 0.9282, -2.1445],\n", - " [ 0.5191, 0.3683, -0.3469, 0.1355]],\n", - "\n", - " [[ 0.0424, -0.3215, 0.5662, -0.4217],\n", - " [ 2.0793, 1.2817, 0.1559, -0.6900],\n", - " [-1.1751, -0.3359, 1.7875, -0.3671],\n", - " [-0.4553, -0.3952, -0.8633, 0.1538],\n", - " [-1.3862, 0.4255, -2.2948, 0.0312]],\n", - "\n", - " [[-1.4257, 2.2662, 0.2670, -0.4330],\n", - " [-0.3244, -0.8669, -0.2571, 0.8028],\n", - " [ 0.9109, -0.2289, -1.2095, -0.9761],\n", - " [-0.0156, 1.2403, -1.1967, 0.6841],\n", - " [-0.8185, 0.2967, -2.1639, -0.7903]],\n", - "\n", - " [[-1.0425, 0.1426, 0.1383, 0.9784],\n", - " [-1.2853, 1.4123, -0.2272, -0.3335],\n", - " [ 1.5751, -0.7663, 0.9610, 0.5686],\n", - " [ 0.9697, -1.5515, -0.8658, -0.5882],\n", - " [-1.2467, 0.0539, 0.1208, -1.0297]]])" - ] - }, - "execution_count": 56, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "o" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(1.8355)" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "loss(o, t)" - ] - }, - { - "cell_type": "code", - "execution_count": 63, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "unsupported operand type(s) for |: 'int' and 'Tensor'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_9275/1867668791.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for |: 'int' and 'Tensor'" - ] - } - ], - "source": [ - "t[:, 2] == 2 | t[:, 2] == 1" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4, 1])" - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.argmax(o, dim=-1)[:, -1:].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [], - "source": [ - "class LabelSmoothingLossCanonical(nn.Module):\n", - " def __init__(self, smoothing=0.0, dim=-1):\n", - " super(LabelSmoothingLossCanonical, self).__init__()\n", - " self.confidence = 1.0 - smoothing\n", - " self.smoothing = smoothing\n", - " self.dim = dim\n", - "\n", - " def forward(self, pred, target):\n", - " pred = pred.log_softmax(dim=self.dim)\n", - " with torch.no_grad():\n", - " # true_dist = pred.data.clone()\n", - " true_dist = torch.zeros_like(pred)\n", - " print(true_dist.shape)\n", - " true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n", - " print(true_dist.shape)\n", - " print(true_dist)\n", - " true_dist.masked_fill_((target == 4).unsqueeze(1), 0)\n", - " print(true_dist)\n", - " true_dist += self.smoothing / pred.size(self.dim)\n", - " return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [], - "source": [ - "l = LabelSmoothingLossCanonical(0.1)" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 5, 4])\n", - "torch.Size([1, 5, 4])\n", - "tensor([[[0.0000, 0.0000, 0.0000, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000],\n", - " [0.9000, 0.9000, 0.0000, 0.9000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000],\n", - " [0.0000, 0.0000, 0.9000, 0.0000]]])\n", - "tensor([[[0.0000, 0.0000, 0.0000, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000],\n", - " [0.9000, 0.9000, 0.0000, 0.9000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000]]])\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor(0.9438)" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "l(o, t)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "en = EfficientNet(\"b0\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_square_subsequent_mask(size: int) -> torch.Tensor:\n", - " \"\"\"Generate a triangular (size, size) mask.\"\"\"\n", - " mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)\n", - " mask = mask.float().masked_fill(mask == 0, float(\"-inf\")).masked_fill(mask == 1, float(0.0))\n", - " return mask" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0., -inf, -inf, -inf],\n", - " [0., 0., -inf, -inf],\n", - " [0., 0., 0., -inf],\n", - " [0., 0., 0., 0.]])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "generate_square_subsequent_mask(4)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "from torch import Tensor" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [], - "source": [ - "tgt = torch.randint(0, 4, (1, 4))\n", - "tgt_mask = torch.ones_like(tgt).bool()" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[True, True, True, True]])" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tgt_mask" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:\n", - " \"\"\"Returns causal target mask.\"\"\"\n", - " trg_pad_mask = (trg != pad_index)[:, None, None]\n", - " trg_len = trg.shape[1]\n", - " trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()\n", - " trg_mask = trg_pad_mask & trg_sub_mask\n", - " return trg_mask" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randint(0, 6, (0, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.Tensor([[0, 0, 0, 3, 3, 3]])" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [], - "source": [ - "tt = t != 3" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ True, True, True, False, False, False]])" - ] - }, - "execution_count": 59, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tt" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.cat((t, t))" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 6])" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[ True, False, False, False, False, False],\n", - " [ True, True, False, False, False, False],\n", - " [ True, True, True, False, False, False],\n", - " [ True, True, True, False, False, False],\n", - " [ True, True, True, False, False, False],\n", - " [ True, True, True, False, False, False]]],\n", - "\n", - "\n", - " [[[ True, False, False, False, False, False],\n", - " [ True, True, False, False, False, False],\n", - " [ True, True, True, False, False, False],\n", - " [ True, True, True, False, False, False],\n", - " [ True, True, True, False, False, False],\n", - " [ True, True, True, False, False, False]]]])" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target_padding_mask(t, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target_padding_mask()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(en, (1, 224, 224));" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "torch.cuda.is_available()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "__init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_13932/689714588.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_heads\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mff_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcross_attend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;34m\"causal\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Cannot set causality on decoder\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcausal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m: __init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'" - ] - } - ], - "source": [ - "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "decoder.cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "transformer_decoder.cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "efficient_transformer = Nystromer(\n", - " dim = 64,\n", - " depth = 4,\n", - " num_heads = 8,\n", - " num_landmarks = 64\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "v = ViT(\n", - " dim = 64,\n", - " image_size = (576, 640),\n", - " patch_size = (32, 32),\n", - " transformer = efficient_transformer\n", - ").cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randn(4, 1, 576, 640).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "EfficientNet(\n", - " (_conv_stem): Sequential(\n", - " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n", - " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", - " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (3): Mish(inplace=True)\n", - " )\n", - " (_blocks): ModuleList(\n", - " (0): MBConvBlock(\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n", - " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (1): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)\n", - " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (2): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n", - " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (3): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n", - " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (4): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n", - " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (5): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n", - " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (6): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", - " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (7): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", - " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (8): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n", - " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (9): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", - " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (10): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", - " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (11): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n", - " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (12): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (13): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (14): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (15): MBConvBlock(\n", - " (_inverted_bottleneck): Sequential(\n", - " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_depthwise): Sequential(\n", - " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n", - " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " (2): Mish(inplace=True)\n", - " )\n", - " (_squeeze_excite): Sequential(\n", - " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Mish(inplace=True)\n", - " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (_pointwise): Sequential(\n", - " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " )\n", - " (_conv_head): Sequential(\n", - " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "en.cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4, 1280, 18, 20])" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "en(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "o = v(t)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "caption = torch.randint(0, 90, (16, 690)).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "o.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "caption.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "o = torch.randn(16, 20 * 18, 128).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "caption = torch.randint(0, 1000, (16, 200)).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "efficient_transformer = efficient_transformer(num_landmarks=256)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "efficient_transformer()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from omegaconf import OmegaConf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "path = \"../training/configs/vqvae.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "conf = OmegaConf.load(path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(OmegaConf.to_yaml(conf))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import VQVAE" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae = VQVAE(**conf.network.args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datum = torch.randn([2, 1, 576, 640])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae.encoder(datum)[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae(datum)[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datum = torch.randn([2, 1, 576, 640])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trg = torch.randint(0, 1000, [2, 682])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trg.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datum = torch.randn([2, 1, 224, 224])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "en(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "path = \"../training/configs/cnn_transformer.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "conf = OmegaConf.load(path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(OmegaConf.to_yaml(conf))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.cnn_transformer import CNNTransformer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.encode(datum).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trg.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t(datum, trg).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b, n = 16, 128\n", - "device = \"cpu\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x = lambda: torch.ones((b, n), device=device).bool()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x().shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "torch.ones((b, n), device=device).bool().shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randn(1, 1, 576, 640)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "576 // 32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "640 // 32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "18 * 20" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randn(1, 1, 144, 160)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from einops import rearrange" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "patch_size=16\n", - "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "p.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} -- cgit v1.2.3-70-g09d2