summaryrefslogtreecommitdiff
path: root/notebooks/00-scratch-pad.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/00-scratch-pad.ipynb')
-rw-r--r--notebooks/00-scratch-pad.ipynb1669
1 files changed, 0 insertions, 1669 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
deleted file mode 100644
index d0f4215..0000000
--- a/notebooks/00-scratch-pad.ipynb
+++ /dev/null
@@ -1,1669 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "\n",
- "%matplotlib inline\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "import torch.nn.functional as F\n",
- "import torch\n",
- "from torch import nn\n",
- "from torchsummary import summary\n",
- "from importlib.util import find_spec\n",
- "if find_spec(\"text_recognizer\") is None:\n",
- " import sys\n",
- " sys.path.append('..')\n",
- "\n",
- "from text_recognizer.networks.transformer.vit import ViT\n",
- "from text_recognizer.networks.transformer.transformer import Transformer\n",
- "from text_recognizer.networks.transformer.layers import Decoder"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randint(0, 5, (4, 4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "36"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "576 // 16"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "40"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "640 // 16"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "1440"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "36 * 40"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[0, 1, 2, 1],\n",
- " [1, 2, 3, 3],\n",
- " [2, 2, 3, 3],\n",
- " [4, 0, 2, 4]])"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randint(0, 5, (1, 4, 4, 4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[[[2, 3, 3, 3],\n",
- " [3, 4, 4, 2],\n",
- " [2, 3, 0, 0],\n",
- " [4, 3, 4, 0]],\n",
- "\n",
- " [[3, 0, 3, 0],\n",
- " [1, 4, 1, 3],\n",
- " [2, 3, 3, 3],\n",
- " [2, 3, 3, 1]],\n",
- "\n",
- " [[1, 1, 0, 3],\n",
- " [1, 3, 0, 4],\n",
- " [3, 1, 4, 2],\n",
- " [3, 1, 4, 3]],\n",
- "\n",
- " [[3, 2, 3, 4],\n",
- " [3, 2, 3, 3],\n",
- " [0, 2, 2, 3],\n",
- " [4, 0, 3, 4]]]])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 4, 16])"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t.flatten(start_dim=2).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[[2, 3, 3, 3, 3, 4, 4, 2, 2, 3, 0, 0, 4, 3, 4, 0],\n",
- " [3, 0, 3, 0, 1, 4, 1, 3, 2, 3, 3, 3, 2, 3, 3, 1],\n",
- " [1, 1, 0, 3, 1, 3, 0, 4, 3, 1, 4, 2, 3, 1, 4, 3],\n",
- " [3, 2, 3, 4, 3, 2, 3, 3, 0, 2, 2, 3, 4, 0, 3, 4]]])"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t.flatten(start_dim=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "ename": "TypeError",
- "evalue": "__init__() got an unexpected keyword argument 'dim'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m/tmp/ipykernel_6532/3641656095.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mflatten\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFlatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'dim'"
- ]
- }
- ],
- "source": [
- "flatten = nn.Flatten(stdim=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "True"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "torch.cuda.is_available()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "loss = nn.CrossEntropyLoss()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 51,
- "metadata": {},
- "outputs": [],
- "source": [
- "o = torch.randn((4, 5, 4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 52,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randint(0, 5, (4, 4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 53,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([4, 5, 4])"
- ]
- },
- "execution_count": 53,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "o.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 54,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([4, 4])"
- ]
- },
- "execution_count": 54,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 55,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[0, 1, 3, 2],\n",
- " [1, 4, 4, 4],\n",
- " [1, 4, 2, 1],\n",
- " [2, 0, 4, 4]])"
- ]
- },
- "execution_count": 55,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 56,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[[ 0.0647, -1.3831, 0.0266, 0.8528],\n",
- " [ 1.4976, 0.4153, 1.0353, 0.0154],\n",
- " [ 1.4562, -0.3568, 0.3599, -0.6222],\n",
- " [ 0.2773, 0.4563, 0.9282, -2.1445],\n",
- " [ 0.5191, 0.3683, -0.3469, 0.1355]],\n",
- "\n",
- " [[ 0.0424, -0.3215, 0.5662, -0.4217],\n",
- " [ 2.0793, 1.2817, 0.1559, -0.6900],\n",
- " [-1.1751, -0.3359, 1.7875, -0.3671],\n",
- " [-0.4553, -0.3952, -0.8633, 0.1538],\n",
- " [-1.3862, 0.4255, -2.2948, 0.0312]],\n",
- "\n",
- " [[-1.4257, 2.2662, 0.2670, -0.4330],\n",
- " [-0.3244, -0.8669, -0.2571, 0.8028],\n",
- " [ 0.9109, -0.2289, -1.2095, -0.9761],\n",
- " [-0.0156, 1.2403, -1.1967, 0.6841],\n",
- " [-0.8185, 0.2967, -2.1639, -0.7903]],\n",
- "\n",
- " [[-1.0425, 0.1426, 0.1383, 0.9784],\n",
- " [-1.2853, 1.4123, -0.2272, -0.3335],\n",
- " [ 1.5751, -0.7663, 0.9610, 0.5686],\n",
- " [ 0.9697, -1.5515, -0.8658, -0.5882],\n",
- " [-1.2467, 0.0539, 0.1208, -1.0297]]])"
- ]
- },
- "execution_count": 56,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "o"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 57,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor(1.8355)"
- ]
- },
- "execution_count": 57,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "loss(o, t)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 63,
- "metadata": {},
- "outputs": [
- {
- "ename": "TypeError",
- "evalue": "unsupported operand type(s) for |: 'int' and 'Tensor'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m/tmp/ipykernel_9275/1867668791.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for |: 'int' and 'Tensor'"
- ]
- }
- ],
- "source": [
- "t[:, 2] == 2 | t[:, 2] == 1"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 60,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([4, 1])"
- ]
- },
- "execution_count": 60,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "torch.argmax(o, dim=-1)[:, -1:].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 43,
- "metadata": {},
- "outputs": [],
- "source": [
- "class LabelSmoothingLossCanonical(nn.Module):\n",
- " def __init__(self, smoothing=0.0, dim=-1):\n",
- " super(LabelSmoothingLossCanonical, self).__init__()\n",
- " self.confidence = 1.0 - smoothing\n",
- " self.smoothing = smoothing\n",
- " self.dim = dim\n",
- "\n",
- " def forward(self, pred, target):\n",
- " pred = pred.log_softmax(dim=self.dim)\n",
- " with torch.no_grad():\n",
- " # true_dist = pred.data.clone()\n",
- " true_dist = torch.zeros_like(pred)\n",
- " print(true_dist.shape)\n",
- " true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n",
- " print(true_dist.shape)\n",
- " print(true_dist)\n",
- " true_dist.masked_fill_((target == 4).unsqueeze(1), 0)\n",
- " print(true_dist)\n",
- " true_dist += self.smoothing / pred.size(self.dim)\n",
- " return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 44,
- "metadata": {},
- "outputs": [],
- "source": [
- "l = LabelSmoothingLossCanonical(0.1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 45,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([1, 5, 4])\n",
- "torch.Size([1, 5, 4])\n",
- "tensor([[[0.0000, 0.0000, 0.0000, 0.0000],\n",
- " [0.0000, 0.0000, 0.0000, 0.0000],\n",
- " [0.9000, 0.9000, 0.0000, 0.9000],\n",
- " [0.0000, 0.0000, 0.0000, 0.0000],\n",
- " [0.0000, 0.0000, 0.9000, 0.0000]]])\n",
- "tensor([[[0.0000, 0.0000, 0.0000, 0.0000],\n",
- " [0.0000, 0.0000, 0.0000, 0.0000],\n",
- " [0.9000, 0.9000, 0.0000, 0.9000],\n",
- " [0.0000, 0.0000, 0.0000, 0.0000],\n",
- " [0.0000, 0.0000, 0.0000, 0.0000]]])\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "tensor(0.9438)"
- ]
- },
- "execution_count": 45,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "l(o, t)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "en = EfficientNet(\"b0\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "def generate_square_subsequent_mask(size: int) -> torch.Tensor:\n",
- " \"\"\"Generate a triangular (size, size) mask.\"\"\"\n",
- " mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)\n",
- " mask = mask.float().masked_fill(mask == 0, float(\"-inf\")).masked_fill(mask == 1, float(0.0))\n",
- " return mask"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[0., -inf, -inf, -inf],\n",
- " [0., 0., -inf, -inf],\n",
- " [0., 0., 0., -inf],\n",
- " [0., 0., 0., 0.]])"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "generate_square_subsequent_mask(4)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "from torch import Tensor"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 52,
- "metadata": {},
- "outputs": [],
- "source": [
- "tgt = torch.randint(0, 4, (1, 4))\n",
- "tgt_mask = torch.ones_like(tgt).bool()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 53,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[True, True, True, True]])"
- ]
- },
- "execution_count": 53,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tgt_mask"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:\n",
- " \"\"\"Returns causal target mask.\"\"\"\n",
- " trg_pad_mask = (trg != pad_index)[:, None, None]\n",
- " trg_len = trg.shape[1]\n",
- " trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()\n",
- " trg_mask = trg_pad_mask & trg_sub_mask\n",
- " return trg_mask"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 54,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randint(0, 6, (0, 4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 55,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.Tensor([[0, 0, 0, 3, 3, 3]])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 58,
- "metadata": {},
- "outputs": [],
- "source": [
- "tt = t != 3"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 59,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[ True, True, True, False, False, False]])"
- ]
- },
- "execution_count": 59,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tt"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 43,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.cat((t, t))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 44,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 6])"
- ]
- },
- "execution_count": 44,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 45,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[[[ True, False, False, False, False, False],\n",
- " [ True, True, False, False, False, False],\n",
- " [ True, True, True, False, False, False],\n",
- " [ True, True, True, False, False, False],\n",
- " [ True, True, True, False, False, False],\n",
- " [ True, True, True, False, False, False]]],\n",
- "\n",
- "\n",
- " [[[ True, False, False, False, False, False],\n",
- " [ True, True, False, False, False, False],\n",
- " [ True, True, True, False, False, False],\n",
- " [ True, True, True, False, False, False],\n",
- " [ True, True, True, False, False, False],\n",
- " [ True, True, True, False, False, False]]]])"
- ]
- },
- "execution_count": 45,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "target_padding_mask(t, 3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "target_padding_mask()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(en, (1, 224, 224));"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "torch.cuda.is_available()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "ename": "TypeError",
- "evalue": "__init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m/tmp/ipykernel_13932/689714588.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_heads\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mff_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcross_attend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;34m\"causal\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Cannot set causality on decoder\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcausal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m: __init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'"
- ]
- }
- ],
- "source": [
- "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "decoder.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "transformer_decoder.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "efficient_transformer = Nystromer(\n",
- " dim = 64,\n",
- " depth = 4,\n",
- " num_heads = 8,\n",
- " num_landmarks = 64\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "v = ViT(\n",
- " dim = 64,\n",
- " image_size = (576, 640),\n",
- " patch_size = (32, 32),\n",
- " transformer = efficient_transformer\n",
- ").cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randn(4, 1, 576, 640).cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "EfficientNet(\n",
- " (_conv_stem): Sequential(\n",
- " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n",
- " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
- " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (3): Mish(inplace=True)\n",
- " )\n",
- " (_blocks): ModuleList(\n",
- " (0): MBConvBlock(\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n",
- " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (1): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)\n",
- " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (2): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n",
- " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (3): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n",
- " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (4): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n",
- " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (5): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n",
- " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (6): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n",
- " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (7): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n",
- " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (8): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n",
- " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (9): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n",
- " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (10): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n",
- " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (11): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n",
- " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (12): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (13): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (14): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " (15): MBConvBlock(\n",
- " (_inverted_bottleneck): Sequential(\n",
- " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_depthwise): Sequential(\n",
- " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n",
- " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " (2): Mish(inplace=True)\n",
- " )\n",
- " (_squeeze_excite): Sequential(\n",
- " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): Mish(inplace=True)\n",
- " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (_pointwise): Sequential(\n",
- " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " (_conv_head): Sequential(\n",
- " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "en.cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([4, 1280, 18, 20])"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "en(t).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "o = v(t)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "caption = torch.randint(0, 90, (16, 690)).cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "o.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "caption.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "o = torch.randn(16, 20 * 18, 128).cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "caption = torch.randint(0, 1000, (16, 200)).cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "efficient_transformer = efficient_transformer(num_landmarks=256)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "efficient_transformer()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from omegaconf import OmegaConf"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "path = \"../training/configs/vqvae.yaml\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "conf = OmegaConf.load(path)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(OmegaConf.to_yaml(conf))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks import VQVAE"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "vae = VQVAE(**conf.network.args)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "vae"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "datum = torch.randn([2, 1, 576, 640])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "vae.encoder(datum)[0].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "vae(datum)[0].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "datum = torch.randn([2, 1, 576, 640])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trg = torch.randint(0, 1000, [2, 682])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trg.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "datum = torch.randn([2, 1, 224, 224])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "en(t).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "path = \"../training/configs/cnn_transformer.yaml\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "conf = OmegaConf.load(path)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(OmegaConf.to_yaml(conf))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.cnn_transformer import CNNTransformer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t.encode(datum).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trg.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t(datum, trg).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "b, n = 16, 128\n",
- "device = \"cpu\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "x = lambda: torch.ones((b, n), device=device).bool()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "x().shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "torch.ones((b, n), device=device).bool().shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "x = torch.randn(1, 1, 576, 640)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "576 // 32"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "640 // 32"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "18 * 20"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "x = torch.randn(1, 1, 144, 160)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from einops import rearrange"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "patch_size=16\n",
- "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "p.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}