diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-06 23:19:35 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-06 23:19:35 +0200 |
commit | 01d6e5fc066969283df99c759609df441151e9c5 (patch) | |
tree | ecd1459e142356d0c7f50a61307b760aca813248 /notebooks | |
parent | f4688482b4898c0b342d6ae59839dc27fbf856c6 (diff) |
Working on fixing decoder transformer
Diffstat (limited to 'notebooks')
-rw-r--r-- | notebooks/00-scratch-pad.ipynb | 431 | ||||
-rw-r--r-- | notebooks/03-look-at-iam-paragraphs.ipynb | 2 |
2 files changed, 250 insertions, 183 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index 3c44f2b..8db843c 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -34,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -43,7 +34,7 @@ "True" ] }, - "execution_count": 7, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -54,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -63,187 +54,224 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "decoder = Decoder(dim=128, depth=4, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "decoder.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.transformer.transformer import Transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "decoder = Decoder(dim=256, depth=4, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" + "transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Decoder(\n", - " (layers): ModuleList(\n", - " (0): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + "Transformer(\n", + " (attn_layers): Decoder(\n", + " (layers): ModuleList(\n", + " (0): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (1): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " (1): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (2): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=256, out_features=2048, bias=True)\n", + " (2): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=1024, out_features=256, bias=True)\n", " )\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (3): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " (3): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (4): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " (4): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (5): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=256, out_features=2048, bias=True)\n", + " (5): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=1024, out_features=256, bias=True)\n", " )\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (6): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " (6): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (7): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " (7): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (8): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=256, out_features=2048, bias=True)\n", + " (8): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=1024, out_features=256, bias=True)\n", " )\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (9): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " (9): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (10): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=256, out_features=49152, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " (10): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Attention(\n", + " (qkv_fn): Sequential(\n", + " (0): Linear(in_features=128, out_features=24576, bias=False)\n", + " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", + " )\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (fc): Linear(in_features=8192, out_features=128, bias=True)\n", " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=16384, out_features=256, bias=True)\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", - " )\n", - " (11): ModuleList(\n", - " (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=256, out_features=2048, bias=True)\n", + " (11): ModuleList(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): FeedForward(\n", + " (mlp): Sequential(\n", + " (0): GEGLU(\n", + " (fc): Linear(in_features=128, out_features=1024, bias=True)\n", + " )\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=1024, out_features=256, bias=True)\n", " )\n", + " (2): Residual()\n", " )\n", - " (2): Residual()\n", " )\n", " )\n", + " (token_emb): Embedding(90, 128)\n", + " (emb_dropout): Dropout(p=0.1, inplace=False)\n", + " (pos_emb): AbsolutePositionalEmbedding(\n", + " (emb): Embedding(690, 128)\n", + " )\n", + " (project_emb): Identity()\n", + " (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (logits): Linear(in_features=128, out_features=90, bias=True)\n", ")" ] }, - "execution_count": 14, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "decoder.cuda()" + "transformer_decoder.cuda()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -252,16 +280,7 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -275,7 +294,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -284,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -298,29 +317,95 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "t = torch.randn(32, 1, 576, 640).cuda()" + "t = torch.randn(16, 1, 576, 640).cuda()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "v(t).shape" + "o = v(t)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.encoders.efficientnet import EfficientNet" + "caption = torch.randint(0, 90, (16, 690)).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 360, 128])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "o.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 690])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "caption.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "forward() missing 2 required positional arguments: 'context' and 'context_mask'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-29-2290911ad81b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtransformer_decoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcaption\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (1, 1024, 20000)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask, return_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject_emb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattn_layers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mreturn_embeddings\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, context, mask, context_mask)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"a\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mlayer_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"c\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'context' and 'context_mask'" + ] + } + ], + "source": [ + "transformer_decoder(caption, context = o) # (1, 1024, 20000)" ] }, { @@ -329,7 +414,7 @@ "metadata": {}, "outputs": [], "source": [ - "en = EfficientNet()" + "from text_recognizer.networks.encoders.efficientnet import EfficientNet" ] }, { @@ -338,7 +423,7 @@ "metadata": {}, "outputs": [], "source": [ - "(576, 640) // (8, 8)" + "en = EfficientNet()" ] }, { @@ -347,7 +432,7 @@ "metadata": {}, "outputs": [], "source": [ - "(576 // 32) ** 2" + "en.cuda()" ] }, { @@ -482,24 +567,6 @@ "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.backbones.efficientnet import EfficientNet" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "en = EfficientNet()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "datum = torch.randn([2, 1, 576, 640])" ] }, @@ -536,7 +603,7 @@ "metadata": {}, "outputs": [], "source": [ - "en(datum).shape" + "en(t).shape" ] }, { @@ -752,7 +819,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.9.5" } }, "nbformat": 4, diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index b72019b..e05704d 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -505,7 +505,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.9.5" } }, "nbformat": 4, |