diff options
| -rw-r--r-- | notebooks/00-scratch-pad.ipynb | 431 | ||||
| -rw-r--r-- | notebooks/03-look-at-iam-paragraphs.ipynb | 2 | ||||
| -rw-r--r-- | pyproject.toml | 2 | ||||
| -rw-r--r-- | text_recognizer/networks/cnn_transformer.py | 182 | ||||
| -rw-r--r-- | text_recognizer/networks/transducer/__init__.py | 3 | ||||
| -rw-r--r-- | text_recognizer/networks/transducer/tds_conv.py | 208 | ||||
| -rw-r--r-- | text_recognizer/networks/transducer/test.py | 60 | ||||
| -rw-r--r-- | text_recognizer/networks/transducer/transducer.py | 410 | ||||
| -rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 2 | ||||
| -rw-r--r-- | text_recognizer/networks/transformer/layers.py | 5 | ||||
| -rw-r--r-- | text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py | 1 | ||||
| -rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 7 | ||||
| -rw-r--r-- | text_recognizer/networks/util.py | 39 | 
13 files changed, 260 insertions, 1092 deletions
| diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index 3c44f2b..8db843c 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -2,18 +2,9 @@   "cells": [    {     "cell_type": "code", -   "execution_count": 6, +   "execution_count": 1,     "metadata": {}, -   "outputs": [ -    { -     "name": "stdout", -     "output_type": "stream", -     "text": [ -      "The autoreload extension is already loaded. To reload it, use:\n", -      "  %reload_ext autoreload\n" -     ] -    } -   ], +   "outputs": [],     "source": [      "%load_ext autoreload\n",      "%autoreload 2\n", @@ -34,7 +25,7 @@    },    {     "cell_type": "code", -   "execution_count": 7, +   "execution_count": 2,     "metadata": {},     "outputs": [      { @@ -43,7 +34,7 @@         "True"        ]       }, -     "execution_count": 7, +     "execution_count": 2,       "metadata": {},       "output_type": "execute_result"      } @@ -54,7 +45,7 @@    },    {     "cell_type": "code", -   "execution_count": 8, +   "execution_count": 3,     "metadata": {},     "outputs": [],     "source": [ @@ -63,187 +54,224 @@    },    {     "cell_type": "code", -   "execution_count": 13, +   "execution_count": 4, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "decoder = Decoder(dim=128, depth=4, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "decoder.cuda()" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 5, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "from text_recognizer.networks.transformer.transformer import Transformer" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 6,     "metadata": {},     "outputs": [],     "source": [ -    "decoder = Decoder(dim=256, depth=4, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" +    "transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"     ]    },    {     "cell_type": "code", -   "execution_count": 14, +   "execution_count": 7,     "metadata": {},     "outputs": [      {       "data": {        "text/plain": [ -       "Decoder(\n", -       "  (layers): ModuleList(\n", -       "    (0): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "Transformer(\n", +       "  (attn_layers): Decoder(\n", +       "    (layers): ModuleList(\n", +       "      (0): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (1): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "      (1): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (2): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): FeedForward(\n", -       "        (mlp): Sequential(\n", -       "          (0): GEGLU(\n", -       "            (fc): Linear(in_features=256, out_features=2048, bias=True)\n", +       "      (2): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): FeedForward(\n", +       "          (mlp): Sequential(\n", +       "            (0): GEGLU(\n", +       "              (fc): Linear(in_features=128, out_features=1024, bias=True)\n", +       "            )\n", +       "            (1): Dropout(p=0.0, inplace=False)\n", +       "            (2): Linear(in_features=512, out_features=128, bias=True)\n",         "          )\n", -       "          (1): Dropout(p=0.0, inplace=False)\n", -       "          (2): Linear(in_features=1024, out_features=256, bias=True)\n",         "        )\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (3): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "      (3): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (4): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "      (4): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (5): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): FeedForward(\n", -       "        (mlp): Sequential(\n", -       "          (0): GEGLU(\n", -       "            (fc): Linear(in_features=256, out_features=2048, bias=True)\n", +       "      (5): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): FeedForward(\n", +       "          (mlp): Sequential(\n", +       "            (0): GEGLU(\n", +       "              (fc): Linear(in_features=128, out_features=1024, bias=True)\n", +       "            )\n", +       "            (1): Dropout(p=0.0, inplace=False)\n", +       "            (2): Linear(in_features=512, out_features=128, bias=True)\n",         "          )\n", -       "          (1): Dropout(p=0.0, inplace=False)\n", -       "          (2): Linear(in_features=1024, out_features=256, bias=True)\n",         "        )\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (6): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "      (6): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (7): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "      (7): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (8): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): FeedForward(\n", -       "        (mlp): Sequential(\n", -       "          (0): GEGLU(\n", -       "            (fc): Linear(in_features=256, out_features=2048, bias=True)\n", +       "      (8): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): FeedForward(\n", +       "          (mlp): Sequential(\n", +       "            (0): GEGLU(\n", +       "              (fc): Linear(in_features=128, out_features=1024, bias=True)\n", +       "            )\n", +       "            (1): Dropout(p=0.0, inplace=False)\n", +       "            (2): Linear(in_features=512, out_features=128, bias=True)\n",         "          )\n", -       "          (1): Dropout(p=0.0, inplace=False)\n", -       "          (2): Linear(in_features=1024, out_features=256, bias=True)\n",         "        )\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (9): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "      (9): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (10): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): Attention(\n", -       "        (qkv_fn): Sequential(\n", -       "          (0): Linear(in_features=256, out_features=49152, bias=False)\n", -       "          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "      (10): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): Attention(\n", +       "          (qkv_fn): Sequential(\n", +       "            (0): Linear(in_features=128, out_features=24576, bias=False)\n", +       "            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=8)\n", +       "          )\n", +       "          (dropout): Dropout(p=0.0, inplace=False)\n", +       "          (fc): Linear(in_features=8192, out_features=128, bias=True)\n",         "        )\n", -       "        (dropout): Dropout(p=0.0, inplace=False)\n", -       "        (fc): Linear(in_features=16384, out_features=256, bias=True)\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n", -       "    )\n", -       "    (11): ModuleList(\n", -       "      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", -       "      (1): FeedForward(\n", -       "        (mlp): Sequential(\n", -       "          (0): GEGLU(\n", -       "            (fc): Linear(in_features=256, out_features=2048, bias=True)\n", +       "      (11): ModuleList(\n", +       "        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "        (1): FeedForward(\n", +       "          (mlp): Sequential(\n", +       "            (0): GEGLU(\n", +       "              (fc): Linear(in_features=128, out_features=1024, bias=True)\n", +       "            )\n", +       "            (1): Dropout(p=0.0, inplace=False)\n", +       "            (2): Linear(in_features=512, out_features=128, bias=True)\n",         "          )\n", -       "          (1): Dropout(p=0.0, inplace=False)\n", -       "          (2): Linear(in_features=1024, out_features=256, bias=True)\n",         "        )\n", +       "        (2): Residual()\n",         "      )\n", -       "      (2): Residual()\n",         "    )\n",         "  )\n", +       "  (token_emb): Embedding(90, 128)\n", +       "  (emb_dropout): Dropout(p=0.1, inplace=False)\n", +       "  (pos_emb): AbsolutePositionalEmbedding(\n", +       "    (emb): Embedding(690, 128)\n", +       "  )\n", +       "  (project_emb): Identity()\n", +       "  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", +       "  (logits): Linear(in_features=128, out_features=90, bias=True)\n",         ")"        ]       }, -     "execution_count": 14, +     "execution_count": 7,       "metadata": {},       "output_type": "execute_result"      }     ],     "source": [ -    "decoder.cuda()" +    "transformer_decoder.cuda()"     ]    },    {     "cell_type": "code", -   "execution_count": null, +   "execution_count": 8,     "metadata": {},     "outputs": [],     "source": [ @@ -252,16 +280,7 @@    },    {     "cell_type": "code", -   "execution_count": null, -   "metadata": {}, -   "outputs": [], -   "source": [ -    "from functools import partial" -   ] -  }, -  { -   "cell_type": "code", -   "execution_count": null, +   "execution_count": 9,     "metadata": {},     "outputs": [],     "source": [ @@ -275,7 +294,7 @@    },    {     "cell_type": "code", -   "execution_count": null, +   "execution_count": 10,     "metadata": {},     "outputs": [],     "source": [ @@ -284,7 +303,7 @@    },    {     "cell_type": "code", -   "execution_count": null, +   "execution_count": 11,     "metadata": {},     "outputs": [],     "source": [ @@ -298,29 +317,95 @@    },    {     "cell_type": "code", -   "execution_count": null, +   "execution_count": 12,     "metadata": {},     "outputs": [],     "source": [ -    "t = torch.randn(32, 1, 576, 640).cuda()" +    "t = torch.randn(16, 1, 576, 640).cuda()"     ]    },    {     "cell_type": "code", -   "execution_count": null, +   "execution_count": 13,     "metadata": {},     "outputs": [],     "source": [ -    "v(t).shape" +    "o = v(t)"     ]    },    {     "cell_type": "code", -   "execution_count": null, +   "execution_count": 22,     "metadata": {},     "outputs": [],     "source": [ -    "from text_recognizer.networks.encoders.efficientnet import EfficientNet" +    "caption = torch.randint(0, 90, (16, 690)).cuda()" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 23, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "torch.Size([16, 360, 128])" +      ] +     }, +     "execution_count": 23, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "o.shape" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 24, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "torch.Size([16, 690])" +      ] +     }, +     "execution_count": 24, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "caption.shape" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 29, +   "metadata": {}, +   "outputs": [ +    { +     "ename": "TypeError", +     "evalue": "forward() missing 2 required positional arguments: 'context' and 'context_mask'", +     "output_type": "error", +     "traceback": [ +      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", +      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)", +      "\u001b[0;32m<ipython-input-29-2290911ad81b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtransformer_decoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcaption\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (1, 1024, 20000)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", +      "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", +      "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask, return_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m     58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject_emb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattn_layers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     61\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mreturn_embeddings\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", +      "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", +      "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, context, mask, context_mask)\u001b[0m\n\u001b[1;32m     89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     90\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mlayer_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"a\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 91\u001b[0;31m                 \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrotary_pos_emb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     92\u001b[0m             \u001b[0;32melif\u001b[0m \u001b[0mlayer_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"c\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     93\u001b[0m                 \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", +      "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", +      "\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'context' and 'context_mask'" +     ] +    } +   ], +   "source": [ +    "transformer_decoder(caption, context = o) # (1, 1024, 20000)"     ]    },    { @@ -329,7 +414,7 @@     "metadata": {},     "outputs": [],     "source": [ -    "en = EfficientNet()" +    "from text_recognizer.networks.encoders.efficientnet import EfficientNet"     ]    },    { @@ -338,7 +423,7 @@     "metadata": {},     "outputs": [],     "source": [ -    "(576, 640) // (8, 8)" +    "en = EfficientNet()"     ]    },    { @@ -347,7 +432,7 @@     "metadata": {},     "outputs": [],     "source": [ -    "(576 // 32) ** 2" +    "en.cuda()"     ]    },    { @@ -482,24 +567,6 @@     "metadata": {},     "outputs": [],     "source": [ -    "from text_recognizer.networks.backbones.efficientnet import EfficientNet" -   ] -  }, -  { -   "cell_type": "code", -   "execution_count": null, -   "metadata": {}, -   "outputs": [], -   "source": [ -    "en = EfficientNet()" -   ] -  }, -  { -   "cell_type": "code", -   "execution_count": null, -   "metadata": {}, -   "outputs": [], -   "source": [      "datum = torch.randn([2, 1, 576, 640])"     ]    }, @@ -536,7 +603,7 @@     "metadata": {},     "outputs": [],     "source": [ -    "en(datum).shape" +    "en(t).shape"     ]    },    { @@ -752,7 +819,7 @@     "name": "python",     "nbconvert_exporter": "python",     "pygments_lexer": "ipython3", -   "version": "3.9.1" +   "version": "3.9.5"    }   },   "nbformat": 4, diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index b72019b..e05704d 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -505,7 +505,7 @@     "name": "python",     "nbconvert_exporter": "python",     "pygments_lexer": "ipython3", -   "version": "3.9.1" +   "version": "3.9.5"    }   },   "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 33c7287..c3f6ebc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@  name = "text-recognizer"  version = "0.1.0"  authors = ["aktersnurra <gustaf.rydholm@gmail.com>"] -description = "A text recognizer using best pratices in python and deep learning." +description = "Text recognition software using best pratices in python."  license = "MIT"  readme = "README.md"  homepage = "https://github.com/aktersnurra/text-recognizer" diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py deleted file mode 100644 index 80798e1..0000000 --- a/text_recognizer/networks/cnn_transformer.py +++ /dev/null @@ -1,182 +0,0 @@ -# """A Transformer with a cnn backbone. -# -# The network encodes a image with a convolutional backbone to a latent representation, -# i.e. feature maps. A 2d positional encoding is applied to the feature maps for -# spatial information. The resulting feature are then set to a transformer decoder -# together with the target tokens. -# -# TODO: Local attention for lower layer in attention. -# -# """ -# import importlib -# import math -# from typing import Dict, Optional, Union, Sequence, Type -# -# from einops import rearrange -# from omegaconf import DictConfig, OmegaConf -# import torch -# from torch import nn -# from torch import Tensor -# -# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS -# from text_recognizer.networks.transformer import ( -#     Decoder, -#     DecoderLayer, -#     PositionalEncoding, -#     PositionalEncoding2D, -#     target_padding_mask, -# ) -# -# NUM_WORD_PIECES = 1000 -# -# -# class CNNTransformer(nn.Module): -#     def __init__( -#         self, -#         input_dim: Sequence[int], -#         output_dims: Sequence[int], -#         encoder: Union[DictConfig, Dict], -#         vocab_size: Optional[int] = None, -#         num_decoder_layers: int = 4, -#         hidden_dim: int = 256, -#         num_heads: int = 4, -#         expansion_dim: int = 1024, -#         dropout_rate: float = 0.1, -#         transformer_activation: str = "glu", -#         *args, -#         **kwargs, -#     ) -> None: -#         super().__init__() -#         self.vocab_size = ( -#             NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size -#         ) -#         self.pad_index = 3  # TODO: fix me -#         self.hidden_dim = hidden_dim -#         self.max_output_length = output_dims[0] -# -#         # Image backbone -#         self.encoder = self._configure_encoder(encoder) -#         self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) -#         self.feature_map_encoding = PositionalEncoding2D( -#             hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] -#         ) -# -#         # Target token embedding -#         self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) -#         self.trg_position_encoding = PositionalEncoding( -#             hidden_dim, dropout_rate, max_len=output_dims[0] -#         ) -# -#         # Transformer decoder -#         self.decoder = Decoder( -#             decoder_layer=DecoderLayer( -#                 hidden_dim=hidden_dim, -#                 num_heads=num_heads, -#                 expansion_dim=expansion_dim, -#                 dropout_rate=dropout_rate, -#                 activation=transformer_activation, -#             ), -#             num_layers=num_decoder_layers, -#             norm=nn.LayerNorm(hidden_dim), -#         ) -# -#         # Classification head -#         self.head = nn.Linear(hidden_dim, self.vocab_size) -# -#         # Initialize weights -#         self._init_weights() -# -#     def _init_weights(self) -> None: -#         """Initialize network weights.""" -#         self.trg_embedding.weight.data.uniform_(-0.1, 0.1) -#         self.head.bias.data.zero_() -#         self.head.weight.data.uniform_(-0.1, 0.1) -# -#         nn.init.kaiming_normal_( -#             self.encoder_proj.weight.data, -#             a=0, -#             mode="fan_out", -#             nonlinearity="relu", -#         ) -#         if self.encoder_proj.bias is not None: -#             _, fan_out = nn.init._calculate_fan_in_and_fan_out( -#                 self.encoder_proj.weight.data -#             ) -#             bound = 1 / math.sqrt(fan_out) -#             nn.init.normal_(self.encoder_proj.bias, -bound, bound) -# -#     @staticmethod -#     def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: -#         encoder = OmegaConf.create(encoder) -#         args = encoder.args or {} -#         network_module = importlib.import_module("text_recognizer.networks") -#         encoder_class = getattr(network_module, encoder.type) -#         return encoder_class(**args) -# -#     def encode(self, image: Tensor) -> Tensor: -#         """Extracts image features with backbone. -# -#         Args: -#             image (Tensor): Image(s) of handwritten text. -# -#         Retuns: -#             Tensor: Image features. -# -#         Shapes: -#             - image: :math: `(B, C, H, W)` -#             - latent: :math: `(B, T, C)` -# -#         """ -#         # Extract image features. -#         image_features = self.encoder(image) -#         image_features = self.encoder_proj(image_features) -# -#         # Add 2d encoding to the feature maps. -#         image_features = self.feature_map_encoding(image_features) -# -#         # Collapse features maps height and width. -#         image_features = rearrange(image_features, "b c h w -> b (h w) c") -#         return image_features -# -#     def decode(self, memory: Tensor, trg: Tensor) -> Tensor: -#         """Decodes image features with transformer decoder.""" -#         trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) -#         trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) -#         trg = rearrange(trg, "b t d -> t b d") -#         trg = self.trg_position_encoding(trg) -#         trg = rearrange(trg, "t b d -> b t d") -#         out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) -#         logits = self.head(out) -#         return logits -# -#     def forward(self, image: Tensor, trg: Tensor) -> Tensor: -#         image_features = self.encode(image) -#         output = self.decode(image_features, trg) -#         output = rearrange(output, "b t c -> b c t") -#         return output -# -#     def predict(self, image: Tensor) -> Tensor: -#         """Transcribes text in image(s).""" -#         bsz = image.shape[0] -#         image_features = self.encode(image) -# -#         output_tokens = ( -#             (torch.ones((bsz, self.max_output_length)) * self.pad_index) -#             .type_as(image) -#             .long() -#         ) -#         output_tokens[:, 0] = self.start_index -#         for i in range(1, self.max_output_length): -#             trg = output_tokens[:, :i] -#             output = self.decode(image_features, trg) -#             output = torch.argmax(output, dim=-1) -#             output_tokens[:, i] = output[-1:] -# -#         # Set all tokens after end token to be padding. -#         for i in range(1, self.max_output_length): -#             indices = output_tokens[:, i - 1] == self.end_index | ( -#                 output_tokens[:, i - 1] == self.pad_index -#             ) -#             output_tokens[indices, i] = self.pad_index -# -#         return output_tokens diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py deleted file mode 100644 index 8c19a01..0000000 --- a/text_recognizer/networks/transducer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transducer modules.""" -from .tds_conv import TDS2d -from .transducer import load_transducer_loss, Transducer diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py deleted file mode 100644 index 5fb8ba9..0000000 --- a/text_recognizer/networks/transducer/tds_conv.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Time-Depth Separable Convolutions. - -References: -    https://arxiv.org/abs/1904.02619 -    https://arxiv.org/pdf/2010.01003.pdf - -Code stolen from: -    https://github.com/facebookresearch/gtn_applications - - -""" -from typing import List, Tuple - -from einops import rearrange -import gtn -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class TDSBlock2d(nn.Module): -    """Internal block of a 2D TDSC network.""" - -    def __init__( -        self, -        in_channels: int, -        img_depth: int, -        kernel_size: Tuple[int], -        dropout_rate: float, -    ) -> None: -        super().__init__() - -        self.in_channels = in_channels -        self.img_depth = img_depth -        self.kernel_size = kernel_size -        self.dropout_rate = dropout_rate -        self.fc_dim = in_channels * img_depth - -        # Network placeholders. -        self.conv = None -        self.mlp = None -        self.instance_norm = None - -        self._build_block() - -    def _build_block(self) -> None: -        # Convolutional block. -        self.conv = nn.Sequential( -            nn.Conv3d( -                in_channels=self.in_channels, -                out_channels=self.in_channels, -                kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), -                padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), -            ), -            nn.ReLU(inplace=True), -            nn.Dropout(self.dropout_rate), -        ) - -        # MLP block. -        self.mlp = nn.Sequential( -            nn.Linear(self.fc_dim, self.fc_dim), -            nn.ReLU(inplace=True), -            nn.Dropout(self.dropout_rate), -            nn.Linear(self.fc_dim, self.fc_dim), -            nn.Dropout(self.dropout_rate), -        ) - -        # Instance norm. -        self.instance_norm = nn.ModuleList( -            [ -                nn.InstanceNorm2d(self.fc_dim, affine=True), -                nn.InstanceNorm2d(self.fc_dim, affine=True), -            ] -        ) - -    def forward(self, x: Tensor) -> Tensor: -        """Forward pass. - -        Args: -            x (Tensor): Input tensor. - -        Shape: -            - x: :math: `(B, CD, H, W)` - -        Returns: -            Tensor: Output tensor. - -        """ -        B, CD, H, W = x.shape -        C, D = self.in_channels, self.img_depth -        residual = x -        x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) -        x = self.conv(x) -        x = rearrange(x, "b c d h w -> b (c d) h w") -        x += residual - -        x = self.instance_norm[0](x) - -        x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x -        x + self.instance_norm[1](x) - -        # Output shape: [B, CD, H, W] -        return x - - -class TDS2d(nn.Module): -    """TDS Netowrk. - -    Structure is the following: -        Downsample layer -> TDS2d group -> ... -> Linear output layer - - -    """ - -    def __init__( -        self, -        input_dim: int, -        output_dim: int, -        depth: int, -        tds_groups: Tuple[int], -        kernel_size: Tuple[int], -        dropout_rate: float, -        in_channels: int = 1, -    ) -> None: -        super().__init__() - -        self.in_channels = in_channels -        self.input_dim = input_dim -        self.output_dim = output_dim -        self.depth = depth -        self.tds_groups = tds_groups -        self.kernel_size = kernel_size -        self.dropout_rate = dropout_rate - -        self.tds = None -        self.fc = None - -        self._build_network() - -    def _build_network(self) -> None: -        in_channels = self.in_channels -        modules = [] -        stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) -        if self.input_dim % stride_h: -            raise RuntimeError( -                f"Image height not divisible by total stride {stride_h}." -            ) - -        for tds_group in self.tds_groups: -            # Add downsample layer. -            out_channels = self.depth * tds_group["channels"] -            modules.extend( -                [ -                    nn.Conv2d( -                        in_channels=in_channels, -                        out_channels=out_channels, -                        kernel_size=self.kernel_size, -                        padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), -                        stride=tds_group["stride"], -                    ), -                    nn.ReLU(inplace=True), -                    nn.Dropout(self.dropout_rate), -                    nn.InstanceNorm2d(out_channels, affine=True), -                ] -            ) - -            for _ in range(tds_group["num_blocks"]): -                modules.append( -                    TDSBlock2d( -                        tds_group["channels"], -                        self.depth, -                        self.kernel_size, -                        self.dropout_rate, -                    ) -                ) - -            in_channels = out_channels - -        self.tds = nn.Sequential(*modules) -        self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) - -    def forward(self, x: Tensor) -> Tensor: -        """Forward pass. - -        Args: -            x (Tensor): Input tensor. - -        Shape: -            - x: :math: `(B, H, W)` - -        Returns: -            Tensor: Output tensor. - -        """ -        if len(x.shape) == 4: -            x = x.squeeze(1)  # Squeeze the channel dim away. - -        B, H, W = x.shape -        x = rearrange( -            x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels -        ) -        x = self.tds(x) - -        # x shape: [B, C, H, W] -        x = rearrange(x, "b c h w -> b w (c h)") - -        return self.fc(x) diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py deleted file mode 100644 index cadcecc..0000000 --- a/text_recognizer/networks/transducer/test.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import nn - -from text_recognizer.networks.transducer import load_transducer_loss, Transducer -import unittest - - -class TestTransducer(unittest.TestCase): -    def test_viterbi(self): -        T = 5 -        N = 4 -        B = 2 - -        # fmt: off -        emissions1 = torch.tensor(( -            0, 4, 0, 1, -            0, 2, 1, 1, -            0, 0, 0, 2, -            0, 0, 0, 2, -            8, 0, 0, 2, -            ), -            dtype=torch.float, -        ).view(T, N) -        emissions2 = torch.tensor(( -            0, 2, 1, 7, -            0, 2, 9, 1, -            0, 0, 0, 2, -            0, 0, 5, 2, -            1, 0, 0, 2, -            ), -            dtype=torch.float, -        ).view(T, N) -        # fmt: on - -        # Test without blank: -        labels = [[1, 3, 0], [3, 2, 3, 2, 3]] -        transducer = Transducer( -            tokens=["a", "b", "c", "d"], -            graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, -            blank="none", -        ) -        emissions = torch.stack([emissions1, emissions2], dim=0) -        predictions = transducer.viterbi(emissions) -        self.assertEqual([p.tolist() for p in predictions], labels) - -        # Test with blank without repeats: -        labels = [[1, 0], [2, 2]] -        transducer = Transducer( -            tokens=["a", "b", "c"], -            graphemes_to_idx={"a": 0, "b": 1, "c": 2}, -            blank="optional", -            allow_repeats=False, -        ) -        emissions = torch.stack([emissions1, emissions2], dim=0) -        predictions = transducer.viterbi(emissions) -        self.assertEqual([p.tolist() for p in predictions], labels) - - -if __name__ == "__main__": -    unittest.main() diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py deleted file mode 100644 index d7e3d08..0000000 --- a/text_recognizer/networks/transducer/transducer.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Transducer and the transducer loss function.py - -Stolen from: -    https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py - -""" -from pathlib import Path -import itertools -from typing import Dict, List, Optional, Union, Tuple - -from loguru import logger -import gtn -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.datasets.iam_preprocessor import Preprocessor - - -def make_scalar_graph(weight) -> gtn.Graph: -    scalar = gtn.Graph() -    scalar.add_node(True) -    scalar.add_node(False, True) -    scalar.add_arc(0, 1, 0, 0, weight) -    return scalar - - -def make_chain_graph(sequence) -> gtn.Graph: -    graph = gtn.Graph(False) -    graph.add_node(True) -    for i, s in enumerate(sequence): -        graph.add_node(False, i == (len(sequence) - 1)) -        graph.add_arc(i, i + 1, s) -    return graph - - -def make_transitions_graph( -    ngram: int, num_tokens: int, calc_grad: bool = False -) -> gtn.Graph: -    transitions = gtn.Graph(calc_grad) -    transitions.add_node(True, ngram == 1) - -    state_map = {(): 0} - -    # First build transitions which include <s>: -    for n in range(1, ngram): -        for state in itertools.product(range(num_tokens), repeat=n): -            in_idx = state_map[state[:-1]] -            out_idx = transitions.add_node(False, ngram == 1) -            state_map[state] = out_idx -            transitions.add_arc(in_idx, out_idx, state[-1]) - -    for state in itertools.product(range(num_tokens), repeat=ngram): -        state_idx = state_map[state[:-1]] -        new_state_idx = state_map[state[1:]] -        # p(state[-1] | state[:-1]) -        transitions.add_arc(state_idx, new_state_idx, state[-1]) - -    if ngram > 1: -        # Build transitions which include </s>: -        end_idx = transitions.add_node(False, True) -        for in_idx in range(end_idx): -            transitions.add_arc(in_idx, end_idx, gtn.epsilon) - -    return transitions - - -def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: -    """Constructs a graph which transduces letters to word pieces.""" -    graph = gtn.Graph(False) -    graph.add_node(True, True) -    for i, wp in enumerate(word_pieces): -        prev = 0 -        for l in wp[:-1]: -            n = graph.add_node() -            graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) -            prev = n -        graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) -    graph.arc_sort() -    return graph - - -def make_token_graph( -    token_list: List, blank: str = "none", allow_repeats: bool = True -) -> gtn.Graph: -    """Constructs a graph with all the individual token transition models.""" -    if not allow_repeats and blank != "optional": -        raise ValueError("Must use blank='optional' if disallowing repeats.") - -    ntoks = len(token_list) -    graph = gtn.Graph(False) - -    # Creating nodes -    graph.add_node(True, True) -    for i in range(ntoks): -        # We can consume one or more consecutive word -        # pieces for each emission: -        # E.g. [ab, ab, ab] transduces to [ab] -        graph.add_node(False, blank != "forced") - -    if blank != "none": -        graph.add_node() - -    # Creating arcs -    if blank != "none": -        # Blank index is assumed to be last (ntoks) -        graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) -        graph.add_arc(ntoks + 1, 0, gtn.epsilon) - -    for i in range(ntoks): -        graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) -        graph.add_arc(i + 1, i + 1, i, gtn.epsilon) - -        if allow_repeats: -            if blank == "forced": -                # Allow transitions from token to blank only -                graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) -            else: -                # Allow transition from token to blank and all other tokens -                graph.add_arc(i + 1, 0, gtn.epsilon) - -        else: -            # allow transitions to blank and all other tokens except the same token -            graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) -            for j in range(ntoks): -                if i != j: -                    graph.add_arc(i + 1, j + 1, j, j) - -    return graph - - -class TransducerLossFunction(torch.autograd.Function): -    @staticmethod -    def forward( -        ctx, -        inputs, -        targets, -        tokens, -        lexicon, -        transition_params=None, -        transitions=None, -        reduction="none", -    ) -> Tensor: -        B, T, C = inputs.shape - -        losses = [None] * B -        emissions_graphs = [None] * B - -        if transitions is not None: -            if transition_params is None: -                raise ValueError("Specified transitions, but not transition params.") - -            cpu_data = transition_params.cpu().contiguous() -            transitions.set_weights(cpu_data.data_ptr()) -            transitions.calc_grad = transition_params.requires_grad -            transitions.zero_grad() - -        def process(b: int) -> None: -            # Create emission graph: -            emissions = gtn.linear_graph(T, C, inputs.requires_grad) -            cpu_data = inputs[b].cpu().contiguous() -            emissions.set_weights(cpu_data.data_ptr()) -            target = make_chain_graph(targets[b]) -            target.arc_sort(True) - -            # Create token tot grapheme decomposition graph -            tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) -            tokens_target.arc_sort() - -            # Create alignment graph: -            aligments = gtn.project_input( -                gtn.remove(gtn.compose(tokens, tokens_target)) -            ) -            aligments.arc_sort() - -            # Add transitions scores: -            if transitions is not None: -                aligments = gtn.intersect(transitions, aligments) -                aligments.arc_sort() - -            loss = gtn.forward_score(gtn.intersect(emissions, aligments)) - -            # Normalize if needed: -            if transitions is not None: -                norm = gtn.forward_score(gtn.intersect(emissions, transitions)) -                loss = gtn.subtract(loss, norm) - -            losses[b] = gtn.negate(loss) - -            # Save for backward: -            if emissions.calc_grad: -                emissions_graphs[b] = emissions - -        gtn.parallel_for(process, range(B)) - -        ctx.graphs = (losses, emissions_graphs, transitions) -        ctx.input_shape = inputs.shape - -        # Optionally reduce by target length -        if reduction == "mean": -            scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] -        else: -            scales = [1.0] * B - -        ctx.scales = scales - -        loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) -        return torch.mean(loss.to(inputs.device)) - -    @staticmethod -    def backward(ctx, grad_output) -> Tuple: -        losses, emissions_graphs, transitions = ctx.graphs -        scales = ctx.scales - -        B, T, C = ctx.input_shape -        calc_emissions = ctx.needs_input_grad[0] -        input_grad = torch.empty((B, T, C)) if calc_emissions else None - -        def process(b: int) -> None: -            scale = make_scalar_graph(scales[b]) -            gtn.backward(losses[b], scale) -            emissions = emissions_graphs[b] -            if calc_emissions: -                grad = emissions.grad().weights_to_numpy() -                input_grad[b] = torch.tensor(grad).view(1, T, C) - -        gtn.parallel_for(process, range(B)) - -        if calc_emissions: -            input_grad = input_grad.to(grad_output.device) -            input_grad *= grad_output / B - -        if ctx.needs_input_grad[4]: -            grad = transitions.grad().weights_to_numpy() -            transition_grad = torch.tensor(grad).to(grad_output.device) -            transition_grad *= grad_output / B -        else: -            transition_grad = None - -        return ( -            input_grad, -            None,  # target -            None,  # tokens -            None,  # lexicon -            transition_grad,  # transition params -            None,  # transitions graph -            None, -        ) - - -TransducerLoss = TransducerLossFunction.apply - - -class Transducer(nn.Module): -    def __init__( -        self, -        tokens: List, -        graphemes_to_idx: Dict, -        ngram: int = 0, -        transitions: str = None, -        blank: str = "none", -        allow_repeats: bool = True, -        reduction: str = "none", -    ) -> None: -        """A generic transducer loss function. - -        Args: -            tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) -                representing the output tokens of the model (e.g. letters, -                word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] -                could be a list of sub-word tokens. -            graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. -                "a", "b", ..) to their corresponding integer index. -            ngram (int) : Order of the token-level transition model. If `ngram=0` -                then no transition model is used. -            blank (string) : Specifies the usage of blank token -                'none' - do not use blank token -                'optional' - allow an optional blank inbetween tokens -                'forced' - force a blank inbetween tokens (also referred to as garbage token) -            allow_repeats (boolean) : If false, then we don't allow paths with -                consecutive tokens in the alignment graph. This keeps the graph -                unambiguous in the sense that the same input cannot transduce to -                different outputs. -        """ -        super().__init__() -        if blank not in ["optional", "forced", "none"]: -            raise ValueError( -                "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']" -            ) -        self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) -        self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) -        self.ngram = ngram -        if ngram > 0 and transitions is not None: -            raise ValueError("Only one of ngram and transitions may be specified") - -        if ngram > 0: -            transitions = make_transitions_graph( -                ngram, len(tokens) + int(blank != "none"), True -            ) - -        if transitions is not None: -            self.transitions = transitions -            self.transitions.arc_sort() -            self.transitions_params = nn.Parameter( -                torch.zeros(self.transitions.num_arcs()) -            ) -        else: -            self.transitions = None -            self.transitions_params = None -        self.reduction = reduction - -    def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: -        TransducerLoss( -            inputs, -            targets, -            self.tokens, -            self.lexicon, -            self.transitions_params, -            self.transitions, -            self.reduction, -        ) - -    def viterbi(self, outputs: Tensor) -> List[Tensor]: -        B, T, C = outputs.shape - -        if self.transitions is not None: -            cpu_data = self.transition_params.cpu().contiguous() -            self.transitions.set_weights(cpu_data.data_ptr()) -            self.transitions.calc_grad = False - -        self.tokens.arc_sort() - -        paths = [None] * B - -        def process(b: int) -> None: -            emissions = gtn.linear_graph(T, C, False) -            cpu_data = outputs[b].cpu().contiguous() -            emissions.set_weights(cpu_data.data_ptr()) - -            if self.transitions is not None: -                full_graph = gtn.intersect(emissions, self.transitions) -            else: -                full_graph = emissions - -            # Find the best path and remove back-off arcs: -            path = gtn.remove(gtn.viterbi_path(full_graph)) - -            # Left compose the viterbi path with the "aligment to token" -            # transducer to get the outputs: -            path = gtn.compose(path, self.tokens) - -            # When there are ambiguous paths (allow_repeats is true), we take -            # the shortest: -            path = gtn.viterbi_path(path) -            path = gtn.remove(gtn.project_output(path)) -            paths[b] = path.labels_to_list() - -        gtn.parallel_for(process, range(B)) -        predictions = [torch.IntTensor(path) for path in paths] -        return predictions - - -def load_transducer_loss( -    num_features: int, -    ngram: int, -    tokens: str, -    lexicon: str, -    transitions: str, -    blank: str, -    allow_repeats: bool, -    prepend_wordsep: bool = False, -    use_words: bool = False, -    data_dir: Optional[Union[str, Path]] = None, -    reduction: str = "mean", -) -> Tuple[Transducer, int]: -    if data_dir is None: -        data_dir = ( -            Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" -        ) -        logger.debug(f"Using data dir: {data_dir}") -        if not data_dir.exists(): -            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") -    else: -        data_dir = Path(data_dir) -    processed_path = ( -        Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" -    ) -    tokens_path = processed_path / tokens -    lexicon_path = processed_path / lexicon - -    if transitions is not None: -        transitions = gtn.load(str(processed_path / transitions)) - -    preprocessor = Preprocessor( -        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, -    ) - -    num_tokens = preprocessor.num_tokens - -    criterion = Transducer( -        preprocessor.tokens, -        preprocessor.graphemes_to_index, -        ngram=ngram, -        transitions=transitions, -        blank=blank, -        allow_repeats=allow_repeats, -        reduction=reduction, -    ) - -    return criterion, num_tokens + int(blank != "none") diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index a3f3011..d9e63ef 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1 +1,3 @@  """Transformer modules.""" +from .nystromer.nystromer import Nystromer +from .vit import ViT diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index b2c703f..a44a525 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,8 +1,6 @@  """Generates the attention layer architecture."""  from functools import partial -from typing import Any, Dict, Optional, Type - -from click.types import Tuple +from typing import Any, Dict, Optional, Tuple, Type  from torch import nn, Tensor @@ -30,6 +28,7 @@ class AttentionLayers(nn.Module):          pre_norm: bool = True,      ) -> None:          super().__init__() +        self.dim = dim          attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)          norm_fn = partial(norm_fn, dim)          ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py index 9466f6e..7140537 100644 --- a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py +++ b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py @@ -1,4 +1,5 @@  """Absolute positional embedding.""" +import torch  from torch import nn, Tensor diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index 60ab1ce..31088b4 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -19,7 +19,9 @@ class Transformer(nn.Module):          emb_dropout: float = 0.0,          use_pos_emb: bool = True,      ) -> None: +        super().__init__()          dim = attn_layers.dim +        self.attn_layers = attn_layers          emb_dim = emb_dim if emb_dim is not None else dim          self.max_seq_len = max_seq_len @@ -32,7 +34,6 @@ class Transformer(nn.Module):          )          self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() -        self.attn_layers = attn_layers          self.norm = nn.LayerNorm(dim)          self._init_weights() @@ -45,12 +46,12 @@ class Transformer(nn.Module):      def forward(          self,          x: Tensor, -        mask: Optional[Tensor], +        mask: Optional[Tensor] = None,          return_embeddings: bool = False,          **kwargs: Any      ) -> Tensor:          b, n, device = *x.shape, x.device -        x += self.token_emb(x) +        x = self.token_emb(x)          if self.pos_emb is not None:              x += self.pos_emb(x)          x = self.emb_dropout(x) diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 9c6b151..05b10a8 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -22,42 +22,3 @@ def activation_function(activation: str) -> Type[nn.Module]:          ]      )      return activation_fns[activation.lower()] - - -# def configure_backbone(backbone: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: -#     """Loads a backbone network.""" -#     network_module = importlib.import_module("text_recognizer.networks") -#     backbone_class = getattr(network_module, backbone.type) -# -#     if "pretrained" in backbone.args: -#         logger.info("Loading pretrained backbone.") -#         checkpoint_file = Path(__file__).resolve().parents[2] / backbone.args.pop( -#             "pretrained" -#         ) -# -#         # Loading state directory. -#         state_dict = torch.load(checkpoint_file) -#         network_args = state_dict["network_args"] -#         weights = state_dict["model_state"] -# -#         freeze = False -#         if "freeze" in backbone.args and backbone.args["freeze"] is True: -#             backbone.args.pop("freeze") -#             freeze = True -# -#         # Initializes the network with trained weights. -#         backbone_ = backbone_(**backbone.args) -#         backbone_.load_state_dict(weights) -#         if freeze: -#             for params in backbone_.parameters(): -#                 params.requires_grad = False -#     else: -#         backbone_ = getattr(network_module, backbone.type) -#         backbone_ = backbone_(**backbone.args) -# -#     if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: -#         backbone = nn.Sequential( -#             *list(backbone.children())[:][: -backbone_args["remove_layers"]] -#         ) -# -#     return backbone |