From a1d795bf02d14befc62cf600fb48842958148eba Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 23 Jul 2021 14:55:31 +0200
Subject: Complete cnn-transformer network, not tested

---
 notebooks/00-scratch-pad.ipynb                     | 421 +++++----------------
 noxfile.py                                         |   1 -
 text_recognizer/data/iam_preprocessor.py           |   2 +
 text_recognizer/networks/cnn_tranformer.py         |  81 +++-
 .../networks/encoders/efficientnet/efficientnet.py |   2 +-
 5 files changed, 161 insertions(+), 346 deletions(-)

diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 1e30038..2c98064 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -2,9 +2,18 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 5,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2\n",
@@ -30,472 +39,230 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from pathlib import Path"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
-    "import attr"
+    "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
-   "metadata": {},
+   "execution_count": 7,
+   "metadata": {
+    "scrolled": false
+   },
    "outputs": [],
    "source": [
-    "@attr.s\n",
-    "class B(nn.Module):\n",
-    "    input_dim = attr.ib()\n",
-    "    hidden = attr.ib()\n",
-    "    xx = attr.ib(init=False, default=\"hek\")\n",
-    "    \n",
-    "    def __attrs_post_init__(self):\n",
-    "        super().__init__()\n",
-    "        self.fc = nn.Linear(self.input_dim, self.hidden)\n",
-    "        self.xx = \"da\"\n",
-    "    \n",
-    "    def forward(self, x):\n",
-    "        return self.fc(x)"
+    "en = EfficientNet(\"b0\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def f(x):\n",
-    "    return 2\n",
-    "\n",
-    "@attr.s(auto_attribs=True)\n",
-    "class T(B):\n",
-    "    \n",
-    "    h: Path = attr.ib(converter=Path)\n",
-    "    p: int = attr.ib(init=False, default=f(3))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 53,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "TypeError",
-     "evalue": "__init__() missing 1 required positional argument: 'hidden'",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-53-ef8b390156f4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"hej\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'hidden'"
-     ]
-    }
-   ],
-   "source": [
-    "t = T(input_dim=16, h=\"hej\")"
+    "def generate_square_subsequent_mask(size: int) -> torch.Tensor:\n",
+    "    \"\"\"Generate a triangular (size, size) mask.\"\"\"\n",
+    "    mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)\n",
+    "    mask = mask.float().masked_fill(mask == 0, float(\"-inf\")).masked_fill(mask == 1, float(0.0))\n",
+    "    return mask"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 51,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "'da'"
+       "tensor([[0., -inf, -inf, -inf],\n",
+       "        [0., 0., -inf, -inf],\n",
+       "        [0., 0., 0., -inf],\n",
+       "        [0., 0., 0., 0.]])"
       ]
      },
-     "execution_count": 51,
+     "execution_count": 10,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "t.xx"
+    "generate_square_subsequent_mask(4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch import Tensor"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 52,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "2"
-      ]
-     },
-     "execution_count": 52,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "t.p"
+    "tgt = torch.randint(0, 4, (1, 4))\n",
+    "tgt_mask = torch.ones_like(tgt).bool()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 53,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "16"
+       "tensor([[True, True, True, True]])"
       ]
      },
-     "execution_count": 19,
+     "execution_count": 53,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "t.input_dim"
+    "tgt_mask"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
-    "x = torch.rand(16, 16)"
+    "def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:\n",
+    "    \"\"\"Returns causal target mask.\"\"\"\n",
+    "    trg_pad_mask = (trg != pad_index)[:, None, None]\n",
+    "    trg_len = trg.shape[1]\n",
+    "    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()\n",
+    "    trg_mask = trg_pad_mask & trg_sub_mask\n",
+    "    return trg_mask"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 54,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Size([16, 16])"
-      ]
-     },
-     "execution_count": 21,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "x.shape"
+    "t = torch.randint(0, 6, (0, 4))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 55,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "T(input_dim=16, hidden=24, h=PosixPath('hej'))"
-      ]
-     },
-     "execution_count": 23,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "t.cuda()"
+    "t = torch.Tensor([[0, 0, 0, 3, 3, 3]])"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 58,
    "metadata": {},
    "outputs": [],
    "source": [
-    "x = x.cuda()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 25,
-   "metadata": {
-    "scrolled": true
-   },
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[ 3.6047e-01,  1.0200e+00,  3.6786e-01,  1.6077e-01,  3.9281e-02,\n",
-       "          3.2830e-01,  1.3433e-01, -9.0334e-02, -3.8712e-01,  8.1547e-01,\n",
-       "         -5.4483e-01, -9.7471e-01,  3.3706e-01, -9.5283e-01, -1.6271e-01,\n",
-       "          3.8504e-01, -5.0106e-01, -4.8638e-01,  3.7033e-01, -4.9557e-01,\n",
-       "          2.6555e-01,  5.1245e-01,  6.6751e-01, -2.6291e-01],\n",
-       "        [ 1.3811e-01,  7.4522e-01,  4.9935e-01,  3.3878e-01,  1.8501e-01,\n",
-       "          2.2269e-02, -2.0328e-01,  1.4629e-01, -2.2957e-01,  4.1197e-01,\n",
-       "         -1.9555e-01, -4.7609e-01,  9.0206e-02, -8.8568e-01, -2.1618e-01,\n",
-       "          2.8882e-01, -5.4335e-01, -6.6301e-01,  4.9990e-01, -4.0144e-01,\n",
-       "          3.6403e-01,  5.3901e-01,  8.6665e-01, -7.8312e-02],\n",
-       "        [ 1.6493e-02,  4.6157e-01,  2.9500e-02,  2.4190e-01,  6.5753e-01,\n",
-       "          4.3770e-02, -5.3773e-02,  1.8183e-01, -2.5983e-02,  4.1634e-01,\n",
-       "         -3.5218e-01, -5.6129e-01,  4.1452e-01, -1.2265e+00, -5.8544e-01,\n",
-       "          3.6382e-01, -6.4090e-01, -5.8679e-01,  4.3489e-02, -1.1233e-01,\n",
-       "          3.1175e-01,  4.2857e-01,  1.6501e-01, -2.4118e-01],\n",
-       "        [ 9.2361e-02,  6.0196e-01,  1.3081e-02, -8.1091e-02,  4.2342e-01,\n",
-       "         -8.8457e-02, -8.1851e-02, -1.1562e-01, -1.5049e-01,  4.9972e-01,\n",
-       "         -3.0432e-01, -7.8619e-01,  2.1060e-01, -1.0598e+00, -4.6542e-01,\n",
-       "          4.2382e-01, -6.5671e-01, -4.8589e-01,  5.5977e-02, -2.9478e-02,\n",
-       "          8.5718e-02,  4.7685e-01,  4.8351e-01, -2.8142e-01],\n",
-       "        [ 1.3377e-01,  5.4434e-01,  3.4505e-01,  1.1307e-01,  4.4057e-01,\n",
-       "         -7.6075e-03,  1.3841e-01, -1.1497e-01, -1.3177e-01,  8.0254e-01,\n",
-       "         -3.0627e-01, -6.8437e-01,  1.9035e-01, -1.0208e+00, -1.3259e-01,\n",
-       "          5.3231e-01, -4.7814e-01, -5.1266e-01,  2.4646e-02, -3.0552e-01,\n",
-       "          2.7398e-01,  5.8269e-01,  6.5481e-01, -4.2041e-01],\n",
-       "        [ 1.9604e-01,  4.0597e-01,  1.9071e-01, -2.5535e-01,  1.1915e-01,\n",
-       "         -6.7129e-02,  5.4386e-03, -8.2196e-02, -4.2803e-01,  7.0287e-01,\n",
-       "         -3.0026e-01, -7.6001e-01, -5.1471e-03, -7.0283e-01, -9.2978e-02,\n",
-       "          1.2243e-01, -1.8398e-01, -4.7374e-01,  2.7978e-01, -3.6962e-01,\n",
-       "          5.6046e-02,  4.1773e-01,  4.9894e-01, -3.1945e-01],\n",
-       "        [ 1.2657e-01,  3.3224e-01,  6.2830e-02,  1.5718e-01,  4.8844e-01,\n",
-       "         -1.1476e-01, -1.5044e-01,  2.5265e-02, -2.0351e-01,  5.5770e-01,\n",
-       "         -3.6036e-01, -7.4406e-01,  1.6962e-01, -9.6185e-01, -2.9334e-01,\n",
-       "          2.2584e-01, -4.1169e-01, -5.2146e-01,  2.3314e-01, -1.3668e-01,\n",
-       "         -1.9598e-02,  3.8727e-01,  3.6892e-01, -3.3071e-01],\n",
-       "        [ 5.2178e-01,  6.9704e-01,  5.0093e-01,  1.1157e-01,  8.0012e-02,\n",
-       "          3.6931e-01, -6.4927e-02,  1.1126e-01, -2.5117e-01,  5.3017e-01,\n",
-       "         -2.6488e-01, -8.4056e-01,  2.2374e-01, -6.6831e-01, -1.9402e-01,\n",
-       "          7.4174e-02, -4.7763e-01, -2.6912e-01,  5.1009e-01, -5.4239e-01,\n",
-       "          3.0123e-01,  3.7529e-01,  4.1625e-01, -2.0141e-01],\n",
-       "        [ 3.7968e-01,  4.9387e-01,  3.6786e-01, -1.3131e-01,  2.4445e-02,\n",
-       "          2.2155e-01, -4.0087e-02, -1.4872e-01, -5.5030e-01,  6.8958e-01,\n",
-       "         -3.8156e-01, -7.5760e-01,  3.2085e-01, -6.4571e-01,  1.1268e-03,\n",
-       "          3.4251e-02, -2.6440e-01, -2.6374e-01,  5.9787e-01, -4.6502e-01,\n",
-       "          2.0074e-01,  4.5471e-01,  2.4238e-01, -4.3247e-01],\n",
-       "        [ 2.9364e-01,  4.8659e-01,  9.0845e-02,  1.6348e-01,  5.7636e-01,\n",
-       "          4.5485e-01, -1.6781e-01, -1.4557e-01, -8.8814e-02,  6.6351e-01,\n",
-       "         -5.3669e-01, -8.2818e-01,  6.0474e-01, -9.4558e-01, -3.0133e-01,\n",
-       "          3.0310e-01, -5.2493e-01, -2.5948e-01,  1.5857e-01, -4.2695e-01,\n",
-       "          2.1311e-01,  4.6502e-01,  8.7946e-02, -5.5815e-01],\n",
-       "        [ 9.2208e-02,  2.9731e-01,  3.3849e-01, -5.1049e-02,  2.7834e-01,\n",
-       "         -1.1120e-01,  1.1835e-01,  1.3665e-01, -2.1291e-01,  3.5107e-01,\n",
-       "         -9.8108e-02, -5.0180e-01,  2.9894e-01, -7.7726e-01, -8.1317e-02,\n",
-       "          3.5704e-01, -3.6759e-01, -2.2148e-01,  1.1019e-01, -1.4452e-02,\n",
-       "          1.5092e-02,  3.3405e-01,  1.2765e-01, -4.0411e-01],\n",
-       "        [ 2.8927e-02,  4.4180e-01,  1.0994e-01,  5.6124e-01,  4.7174e-01,\n",
-       "          1.9914e-01, -9.5047e-02,  3.1277e-02, -1.8656e-01,  5.0631e-01,\n",
-       "         -3.4353e-01, -5.7425e-01,  4.3409e-01, -8.3343e-01, -1.1627e-01,\n",
-       "          3.1852e-02, -4.1274e-01, -2.6756e-01,  4.9652e-01, -2.6137e-01,\n",
-       "          2.8559e-02,  3.0587e-01,  3.6717e-01, -4.4303e-01],\n",
-       "        [-1.0741e-01,  1.3539e-01,  1.5746e-01,  2.1208e-01,  6.3745e-01,\n",
-       "         -2.1864e-01, -1.8820e-01,  2.1184e-01, -3.6832e-02,  3.0890e-01,\n",
-       "         -2.4719e-03, -3.3573e-01,  1.8479e-01, -9.2119e-01, -2.3361e-01,\n",
-       "          8.9827e-02, -5.4372e-01, -4.4935e-01,  3.2967e-01, -9.2807e-02,\n",
-       "          9.9241e-02,  4.1705e-01,  2.4728e-01, -4.8119e-01],\n",
-       "        [ 2.8125e-01,  5.3276e-01,  5.0110e-02,  2.0471e-01,  5.7750e-01,\n",
-       "          4.6670e-02, -2.1400e-01,  6.8794e-03, -6.8737e-02,  4.2138e-01,\n",
-       "         -3.1261e-01, -7.3709e-01,  4.2001e-01, -9.9757e-01, -4.8091e-01,\n",
-       "          2.9960e-01, -6.2133e-01, -4.0566e-01,  3.2191e-01, -1.0219e-02,\n",
-       "          1.2901e-01,  3.9601e-01,  1.6291e-01, -3.3871e-01],\n",
-       "        [ 2.9181e-01,  5.5400e-01,  3.0462e-01,  2.2431e-02,  2.8480e-01,\n",
-       "          4.4624e-01, -2.8859e-01, -1.4629e-01, -4.3573e-02,  2.9742e-01,\n",
-       "         -1.0100e-01, -4.3070e-01,  4.6713e-01, -3.7132e-01, -8.6748e-02,\n",
-       "          2.5666e-01, -3.5361e-01, -2.3917e-02,  3.0071e-01, -3.2420e-01,\n",
-       "          1.3375e-01,  3.4475e-01,  3.0642e-01, -4.3496e-01],\n",
-       "        [-7.7723e-04,  2.3828e-01,  2.3124e-01,  4.1347e-01,  6.8455e-01,\n",
-       "         -9.8319e-03,  1.3403e-01,  1.8460e-02, -1.4025e-01,  5.9780e-01,\n",
-       "         -3.7015e-01, -5.7865e-01,  4.9211e-01, -1.1262e+00, -2.1693e-01,\n",
-       "          3.2002e-01, -2.9313e-01, -3.1941e-01,  9.8446e-02, -6.2767e-02,\n",
-       "         -9.8636e-03,  3.5712e-01,  2.8833e-01, -5.3506e-01]], device='cuda:0',\n",
-       "       grad_fn=<AddmmBackward>)"
-      ]
-     },
-     "execution_count": 25,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "t(x)"
+    "tt = t != 3"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 59,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "PosixPath('hej')"
+       "tensor([[ True,  True,  True, False, False, False]])"
       ]
      },
-     "execution_count": 13,
+     "execution_count": 59,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "t.h"
+    "tt"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 43,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "16"
-      ]
-     },
-     "execution_count": 12,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "t.batch_size"
+    "t = torch.cat((t, t))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 44,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "PosixPath('hej')"
+       "torch.Size([2, 6])"
       ]
      },
-     "execution_count": 11,
+     "execution_count": 44,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "t.h"
+    "t.shape"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "../text_recognizer/__init__.py\n",
-      "../text_recognizer/callbacks/__init__.py\n",
-      "../text_recognizer/callbacks/wandb_callbacks.py\n",
-      "../text_recognizer/data/image_utils.py\n",
-      "../text_recognizer/data/emnist.py\n",
-      "../text_recognizer/data/iam_lines.py\n",
-      "../text_recognizer/data/download_utils.py\n",
-      "../text_recognizer/data/mappings.py\n",
-      "../text_recognizer/data/iam_preprocessor.py\n",
-      "../text_recognizer/data/__init__.py\n",
-      "../text_recognizer/data/make_wordpieces.py\n",
-      "../text_recognizer/data/iam_paragraphs.py\n",
-      "../text_recognizer/data/sentence_generator.py\n",
-      "../text_recognizer/data/emnist_lines.py\n",
-      "../text_recognizer/data/build_transitions.py\n",
-      "../text_recognizer/data/base_dataset.py\n",
-      "../text_recognizer/data/base_data_module.py\n",
-      "../text_recognizer/data/iam.py\n",
-      "../text_recognizer/data/iam_synthetic_paragraphs.py\n",
-      "../text_recognizer/data/transforms.py\n",
-      "../text_recognizer/data/iam_extended_paragraphs.py\n",
-      "../text_recognizer/networks/__init__.py\n",
-      "../text_recognizer/networks/util.py\n",
-      "../text_recognizer/networks/cnn_tranformer.py\n",
-      "../text_recognizer/networks/encoders/__init__.py\n",
-      "../text_recognizer/networks/encoders/efficientnet/efficientnet.py\n",
-      "../text_recognizer/networks/encoders/efficientnet/__init__.py\n",
-      "../text_recognizer/networks/encoders/efficientnet/utils.py\n",
-      "../text_recognizer/networks/encoders/efficientnet/mbconv.py\n",
-      "../text_recognizer/networks/loss/__init__.py\n",
-      "../text_recognizer/networks/loss/label_smoothing_loss.py\n",
-      "../text_recognizer/networks/vqvae/__init__.py\n",
-      "../text_recognizer/networks/vqvae/decoder.py\n",
-      "../text_recognizer/networks/vqvae/vqvae.py\n",
-      "../text_recognizer/networks/vqvae/vector_quantizer.py\n",
-      "../text_recognizer/networks/vqvae/encoder.py\n",
-      "../text_recognizer/networks/transformer/__init__.py\n",
-      "../text_recognizer/networks/transformer/layers.py\n",
-      "../text_recognizer/networks/transformer/residual.py\n",
-      "../text_recognizer/networks/transformer/attention.py\n",
-      "../text_recognizer/networks/transformer/transformer.py\n",
-      "../text_recognizer/networks/transformer/vit.py\n",
-      "../text_recognizer/networks/transformer/mlp.py\n",
-      "../text_recognizer/networks/transformer/norm.py\n",
-      "../text_recognizer/networks/transformer/positional_encodings/positional_encoding.py\n",
-      "../text_recognizer/networks/transformer/positional_encodings/__init__.py\n",
-      "../text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py\n",
-      "../text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py\n",
-      "../text_recognizer/networks/transformer/nystromer/__init__.py\n",
-      "../text_recognizer/networks/transformer/nystromer/nystromer.py\n",
-      "../text_recognizer/networks/transformer/nystromer/attention.py\n",
-      "../text_recognizer/models/__init__.py\n",
-      "../text_recognizer/models/base.py\n",
-      "../text_recognizer/models/vqvae.py\n",
-      "../text_recognizer/models/transformer.py\n",
-      "../text_recognizer/models/dino.py\n",
-      "../text_recognizer/models/metrics.py\n"
-     ]
-    }
-   ],
-   "source": [
-    "for f in Path(\"../text_recognizer\").glob(\"**/*.py\"):\n",
-    "    print(f)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 45,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "<generator object Path.glob at 0x7ff8bb9ce5f0>"
+       "tensor([[[[ True, False, False, False, False, False],\n",
+       "          [ True,  True, False, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False]]],\n",
+       "\n",
+       "\n",
+       "        [[[ True, False, False, False, False, False],\n",
+       "          [ True,  True, False, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False],\n",
+       "          [ True,  True,  True, False, False, False]]]])"
       ]
      },
-     "execution_count": 12,
+     "execution_count": 45,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "Path(\"..\").glob(\"**/*.py\")"
+    "target_padding_mask(t, 3)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {
-    "scrolled": false
-   },
-   "outputs": [],
-   "source": [
-    "en = EfficientNet(\"b0\")"
+    "target_padding_mask()"
    ]
   },
   {
@@ -1404,7 +1171,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.5"
+   "version": "3.9.6"
   }
  },
  "nbformat": 4,
diff --git a/noxfile.py b/noxfile.py
index a90d53b..d14fefb 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -2,7 +2,6 @@
 import tempfile
 from typing import Any
 
-
 import nox
 from nox.sessions import Session
 
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 506036e..f7457e4 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -47,6 +47,8 @@ def load_metadata(
 class Preprocessor:
     """A preprocessor for the IAM dataset."""
 
+    # TODO: attrs
+
     def __init__(
         self,
         data_dir: Union[str, Path],
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index 5c13e9a..e030cb8 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -3,6 +3,7 @@ import math
 from typing import Tuple, Type
 
 import attr
+import torch
 from torch import nn, Tensor
 
 from text_recognizer.data.mappings import AbstractMapping
@@ -18,13 +19,19 @@ class CnnTransformer(nn.Module):
     def __attrs_pre_init__(self) -> None:
         super().__init__()
 
-    # Parameters,
+    # Parameters and placeholders,
     input_dims: Tuple[int, int, int] = attr.ib()
     hidden_dim: int = attr.ib()
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
     padding_idx: int = attr.ib()
+    start_token: str = attr.ib()
+    start_index: int = attr.ib(init=False, default=None)
+    end_token: str = attr.ib()
+    end_index: int = attr.ib(init=False, default=None)
+    pad_token: str = attr.ib()
+    pad_index: int = attr.ib(init=False, default=None)
 
     # Modules.
     encoder: Type[nn.Module] = attr.ib()
@@ -38,6 +45,9 @@ class CnnTransformer(nn.Module):
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
+        self.start_index = int(self.mapping.get_index(self.start_token))
+        self.end_index = int(self.mapping.get_index(self.end_token))
+        self.pad_index = int(self.mapping.get_index(self.pad_token))
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -99,20 +109,20 @@ class CnnTransformer(nn.Module):
         z = self.encoder(x)
         z = self.latent_encoder(z)
 
-        # Permute tensor from [B, E, Ho * Wo] to [Sx, B, E]
-        z = z.permute(2, 0, 1)
+        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
+        z = z.permute(0, 2, 1)
         return z
 
-    def decode(self, z: Tensor, trg: Tensor) -> Tensor:
+    def decode(self, z: Tensor, context: Tensor) -> Tensor:
         """Decodes latent images embedding into word pieces.
 
         Args:
             z (Tensor): Latent image embeddings.
-            trg (Tensor): Word embeddings.
+            context (Tensor): Word embeddings.
 
         Shapes:
             - z: :math: `(B, Sx, E)`
-            - trg: :math: `(B, Sy)`
+            - context: :math: `(B, Sy)`
             - out: :math: `(B, Sy, T)`
 
             where Sy is the length of the output and T is the number of tokens.
@@ -120,32 +130,69 @@ class CnnTransformer(nn.Module):
         Returns:
             Tensor: Sequence of word piece embeddings.
         """
-        trg_mask = trg != self.padding_idx
-        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = self.token_pos_encoder(trg)
-        out = self.decoder(x=trg, context=z, mask=trg_mask)
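+        # Boolean mask that is True for non-padding positions; the decoder uses it to ignore pad tokens.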
+        context_mask = context != self.padding_idx
+        context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
+        context = self.token_pos_encoder(context)
+        out = self.decoder(x=context, context=z, mask=context_mask)
         logits = self.head(out)
         return logits
 
-    def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+    def forward(self, x: Tensor, context: Tensor) -> Tensor:
         """Encodes images into word piece logtis.
 
         Args:
             x (Tensor): Input image(s).
-            trg (Tensor): Target word embeddings.
+            context (Tensor): Target word embeddings.
 
         Shapes:
             - x: :math: `(B, C, H, W)`
-            - trg: :math: `(B, Sy, T)`
+            - context: :math: `(B, Sy, T)`
 
             where B is the batch size, C is the number of input channels, H is
             the image height and W is the image width.
+
+        Returns:
+            Tensor: Sequence of logits.
         """
         z = self.encode(x)
-        logits = self.decode(z, trg)
+        logits = self.decode(z, context)
         return logits
 
     def predict(self, x: Tensor) -> Tensor:
-        """Predicts text in image."""
-        # TODO: continue here!!!!!!!!!
-        pass
+        """Predicts text in image.
+
+        Args:
+            x (Tensor): Image(s) to extract text from.
+
+        Shapes:
+            - x: :math: `(B, C, H, W)`
+            - output: :math: `(B, S)`
+
+        Returns:
+            Tensor: A tensor of token indices of the predictions from the model.
+        """
+        bsz = x.shape[0]
+
+        # Encode image(s) to latent vectors.
+        z = self.encode(x)
+
+        # Create a placeholder matrix for storing outputs from the network
+        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+        output[:, 0] = self.start_index
+
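+        # Greedy decoding: decode the tokens generated so far and append the argmax of the last position at each step.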
+        for i in range(1, self.max_output_len):
+            context = output[:, :i]  # (bsz, i)
+            logits = self.decode(z, context)  # (bsz, i, c)
+            tokens = torch.argmax(logits, dim=-1)  # (bsz, i)
+            output[:, i : i + 1] = tokens[:, -1:]
+
+            # Early stopping of prediction loop if token is end or padding token.
+            if ((output[:, i - 1] == self.end_index) | (output[:, i - 1] == self.pad_index)).all():
+                break
+
+        # Set all tokens after end token to pad token.
+        for i in range(1, self.max_output_len):
+            idx = (output[:, i - 1] == self.end_index) | (output[:, i - 1] == self.pad_index)
+            output[idx, i] = self.pad_index
+
+        return output
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 59598b5..6719efb 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -41,7 +41,7 @@ class EfficientNet(nn.Module):
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
         self._conv_stem: nn.Sequential = None
-        self._blocks: nn.Sequential = None
+        self._blocks: nn.ModuleList = None
         self._conv_head: nn.Sequential = None
         self._build()
 
-- 
cgit v1.2.3-70-g09d2