{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", "from torchsummary import summary\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.transformer.layers import Decoder" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "decoder = Decoder(dim=128, depth=4, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decoder.cuda()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.transformer.transformer import Transformer" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Transformer(\n", " (attn_layers): Decoder(\n", " (layers): ModuleList(\n", " (0): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", " (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (1): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", " (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (2): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): FeedForward(\n", " (mlp): Sequential(\n", " (0): GEGLU(\n", " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", " )\n", " (2): Residual()\n", " )\n", " (3): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", " (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (4): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", 
" (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (5): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): FeedForward(\n", " (mlp): Sequential(\n", " (0): GEGLU(\n", " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", " )\n", " (2): Residual()\n", " )\n", " (6): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", " (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (7): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", " (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (8): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): FeedForward(\n", " (mlp): Sequential(\n", " (0): GEGLU(\n", " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", " )\n", " (2): Residual()\n", " )\n", " (9): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", " (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (10): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Attention(\n", " (qkv_fn): Sequential(\n", " (0): Linear(in_features=128, out_features=24576, bias=False)\n", " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", " (2): Residual()\n", " )\n", " (11): ModuleList(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): FeedForward(\n", " (mlp): Sequential(\n", " (0): GEGLU(\n", " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", " )\n", " (2): Residual()\n", " )\n", " )\n", " )\n", " (token_emb): Embedding(90, 128)\n", " (emb_dropout): Dropout(p=0.1, inplace=False)\n", " (pos_emb): AbsolutePositionalEmbedding(\n", " (emb): Embedding(690, 128)\n", " )\n", " (project_emb): Identity()\n", " (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (logits): Linear(in_features=128, out_features=90, bias=True)\n", ")" ] }, 
"execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "transformer_decoder.cuda()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "efficient_transformer = Nystromer(\n", " dim = 128,\n", " depth = 4,\n", " num_heads = 8,\n", " num_landmarks = 128\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.transformer.vit import ViT" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "v = ViT(\n", " dim = 128,\n", " image_size = (576, 640),\n", " patch_size = (32, 32),\n", " transformer = efficient_transformer\n", ").cuda()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "t = torch.randn(16, 1, 576, 640).cuda()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "o = v(t)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "caption = torch.randint(0, 90, (16, 690)).cuda()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16, 360, 128])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "o.shape" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16, 690])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "caption.shape" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "forward() missing 2 required positional arguments: 'context' and 'context_mask'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtransformer_decoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcaption\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (1, 1024, 20000)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask, return_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject_emb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattn_layers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mreturn_embeddings\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, context, mask, context_mask)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer_type\u001b[0m 
\u001b[0;34m==\u001b[0m \u001b[0;34m\"a\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mlayer_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"c\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'context' and 'context_mask'" ] } ], "source": [ "transformer_decoder(caption, context = o) # (1, 1024, 20000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.encoders.efficientnet import EfficientNet" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en = EfficientNet()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "summary(en, (1, 576, 640))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
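"# The ViT above still holds this Nystromer instance; check it before rebuilding below.\n",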
"type(efficient_transformer)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer = efficient_transformer(num_landmarks=256)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "efficient_transformer()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"../training/configs/vqvae.yaml\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks import VQVAE" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae = VQVAE(**conf.network.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae.encoder(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae(datum)[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg = torch.randint(0, 1000, [2, 682])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datum = torch.randn([2, 1, 224, 224])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en(t).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"../training/configs/cnn_transformer.yaml\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.networks.cnn_transformer import CNNTransformer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t.encode(datum).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trg.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t(datum, trg).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "b, n = 16, 128\n", "device = 
\"cpu\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = lambda: torch.ones((b, n), device=device).bool()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x().shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.ones((b, n), device=device).bool().shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "576 // 32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "640 // 32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "18 * 20" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 1, 144, 160)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "patch_size=16\n", "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 4 }