summaryrefslogtreecommitdiff
path: root/notebooks
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks')
-rw-r--r--notebooks/00-scratch-pad.ipynb321
1 files changed, 149 insertions, 172 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 8db843c..4681360 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -20,7 +20,12 @@
"from importlib.util import find_spec\n",
"if find_spec(\"text_recognizer\") is None:\n",
" import sys\n",
- " sys.path.append('..')"
+ " sys.path.append('..')\n",
+ "\n",
+ "from text_recognizer.networks.transformer.vit import ViT\n",
+ "from text_recognizer.networks.transformer.transformer import Transformer\n",
+ "from text_recognizer.networks.transformer.layers import Decoder\n",
+ "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer"
]
},
{
@@ -49,23 +54,102 @@
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.transformer.layers import Decoder"
+ "decoder = Decoder(dim=64, depth=2, num_heads=4, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
- "outputs": [],
- "source": [
- "decoder = Decoder(dim=128, depth=4, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Decoder(\n",
+ " (layers): ModuleList(\n",
+ " (0): ModuleList(\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (1): ModuleList(\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (2): ModuleList(\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=256, out_features=64, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (3): ModuleList(\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (4): ModuleList(\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): Attention(\n",
+ " (qkv_fn): Sequential(\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " (5): ModuleList(\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): FeedForward(\n",
+ " (mlp): Sequential(\n",
+ " (0): GEGLU(\n",
+ " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=256, out_features=64, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (2): Residual()\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"decoder.cuda()"
]
@@ -76,22 +160,13 @@
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.transformer.transformer import Transformer"
+ "transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=64, emb_dropout=0.1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
- "outputs": [],
- "source": [
- "transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
"outputs": [
{
"data": {
@@ -100,167 +175,93 @@
" (attn_layers): Decoder(\n",
" (layers): ModuleList(\n",
" (0): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
" (1): Attention(\n",
" (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
" )\n",
" (2): Residual()\n",
" )\n",
" (1): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
" (1): Attention(\n",
" (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
" )\n",
" (2): Residual()\n",
" )\n",
" (2): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
" (1): FeedForward(\n",
" (mlp): Sequential(\n",
" (0): GEGLU(\n",
- " (fc): Linear(in_features=128, out_features=1024, bias=True)\n",
+ " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
" )\n",
" (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=512, out_features=128, bias=True)\n",
+ " (2): Linear(in_features=256, out_features=64, bias=True)\n",
" )\n",
" )\n",
" (2): Residual()\n",
" )\n",
" (3): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
" (1): Attention(\n",
" (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
" )\n",
" (2): Residual()\n",
" )\n",
" (4): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
" (1): Attention(\n",
" (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
+ " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
+ " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
+ " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
" )\n",
" (2): Residual()\n",
" )\n",
" (5): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
" (1): FeedForward(\n",
" (mlp): Sequential(\n",
" (0): GEGLU(\n",
- " (fc): Linear(in_features=128, out_features=1024, bias=True)\n",
+ " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
" )\n",
" (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=512, out_features=128, bias=True)\n",
- " )\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (6): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (7): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (8): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
- " (1): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=128, out_features=1024, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=512, out_features=128, bias=True)\n",
- " )\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (9): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (10): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=128, out_features=24576, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=8192, out_features=128, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (11): ModuleList(\n",
- " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
- " (1): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=128, out_features=1024, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=512, out_features=128, bias=True)\n",
+ " (2): Linear(in_features=256, out_features=64, bias=True)\n",
" )\n",
" )\n",
" (2): Residual()\n",
" )\n",
" )\n",
" )\n",
- " (token_emb): Embedding(90, 128)\n",
+ " (token_emb): Embedding(90, 64)\n",
" (emb_dropout): Dropout(p=0.1, inplace=False)\n",
" (pos_emb): AbsolutePositionalEmbedding(\n",
- " (emb): Embedding(690, 128)\n",
+ " (emb): Embedding(690, 64)\n",
" )\n",
" (project_emb): Identity()\n",
- " (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
- " (logits): Linear(in_features=128, out_features=90, bias=True)\n",
+ " (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (logits): Linear(in_features=64, out_features=90, bias=True)\n",
")"
]
},
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -271,62 +272,44 @@
},
{
"cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"efficient_transformer = Nystromer(\n",
- " dim = 128,\n",
+ " dim = 64,\n",
" depth = 4,\n",
" num_heads = 8,\n",
- " num_landmarks = 128\n",
+ " num_landmarks = 64\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.transformer.vit import ViT"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"v = ViT(\n",
- " dim = 128,\n",
+ " dim = 64,\n",
" image_size = (576, 640),\n",
- " patch_size = (32, 32),\n",
+ " patch_size = (64, 64),\n",
" transformer = efficient_transformer\n",
").cuda()"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
- "t = torch.randn(16, 1, 576, 640).cuda()"
+ "t = torch.randn(4, 1, 576, 640).cuda()"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -335,7 +318,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -344,16 +327,16 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([16, 360, 128])"
+ "torch.Size([4, 90, 64])"
]
},
- "execution_count": 23,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -364,7 +347,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -373,7 +356,7 @@
"torch.Size([16, 690])"
]
},
- "execution_count": 24,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -384,28 +367,22 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
- "ename": "TypeError",
- "evalue": "forward() missing 2 required positional arguments: 'context' and 'context_mask'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-29-2290911ad81b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtransformer_decoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcaption\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (1, 1024, 20000)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask, return_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject_emb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattn_layers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mreturn_embeddings\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, context, mask, context_mask)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlayer_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"a\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mlayer_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"c\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'context' and 'context_mask'"
- ]
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 690, 90])"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "transformer_decoder(caption, context = o) # (1, 1024, 20000)"
+ "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)"
]
},
{