summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
commitbd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
treee55cb3744904f7c2a0348b100c7e92a65e538a16
parent75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
-rw-r--r--notebooks/00-scratch-pad.ipynb304
-rw-r--r--notebooks/03-look-at-iam-paragraphs.ipynb183
-rw-r--r--notebooks/05c-test-model-end-to-end.ipynb448
-rw-r--r--poetry.lock58
-rw-r--r--pyproject.toml10
-rw-r--r--text_recognizer/criterions/label_smoothing.py38
-rw-r--r--text_recognizer/data/base_data_module.py6
-rw-r--r--text_recognizer/data/base_mapping.py37
-rw-r--r--text_recognizer/data/download_utils.py2
-rw-r--r--text_recognizer/data/emnist_mapping.py37
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py3
-rw-r--r--text_recognizer/data/iam_lines.py2
-rw-r--r--text_recognizer/data/iam_paragraphs.py12
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py4
-rw-r--r--text_recognizer/data/make_wordpieces.py2
-rw-r--r--text_recognizer/data/mappings.py156
-rw-r--r--text_recognizer/data/transforms.py8
-rw-r--r--text_recognizer/data/word_piece_mapping.py93
-rw-r--r--text_recognizer/models/base.py20
-rw-r--r--text_recognizer/models/transformer.py36
-rw-r--r--text_recognizer/networks/conv_transformer.py42
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py9
-rw-r--r--text_recognizer/networks/transformer/layers.py27
-rw-r--r--training/__init__.py (renamed from training/conf/callbacks/wandb/image_reconstructions.yaml)0
-rw-r--r--training/callbacks/wandb_callbacks.py83
-rw-r--r--training/conf/callbacks/checkpoint.yaml2
-rw-r--r--training/conf/callbacks/wandb_checkpoints.yaml (renamed from training/conf/callbacks/wandb/checkpoints.yaml)0
-rw-r--r--training/conf/callbacks/wandb_code.yaml (renamed from training/conf/callbacks/wandb/code.yaml)0
-rw-r--r--training/conf/callbacks/wandb_image_reconstructions.yaml0
-rw-r--r--training/conf/callbacks/wandb_ocr.yaml8
-rw-r--r--training/conf/callbacks/wandb_ocr_predictions.yaml (renamed from training/conf/callbacks/wandb/ocr_predictions.yaml)0
-rw-r--r--training/conf/callbacks/wandb_watch.yaml (renamed from training/conf/callbacks/wandb/watch.yaml)0
-rw-r--r--training/conf/config.yaml21
-rw-r--r--training/conf/criterion/label_smoothing.yaml5
-rw-r--r--training/conf/datamodule/iam_extended_paragraphs.yaml3
-rw-r--r--training/conf/lr_scheduler/one_cycle.yaml4
-rw-r--r--training/conf/mapping/word_piece.yaml4
-rw-r--r--training/conf/model/lit_transformer.yaml2
-rw-r--r--training/conf/network/conv_transformer.yaml1
-rw-r--r--training/conf/network/decoder/transformer_decoder.yaml1
-rw-r--r--training/run.py19
-rw-r--r--training/utils.py23
42 files changed, 1119 insertions, 594 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 0350727..a193107 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 1,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -33,8 +24,295 @@
"\n",
"from text_recognizer.networks.transformer.vit import ViT\n",
"from text_recognizer.networks.transformer.transformer import Transformer\n",
- "from text_recognizer.networks.transformer.layers import Decoder\n",
- "from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer"
+ "from text_recognizer.networks.transformer.layers import Decoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "o = torch.randn((4, 5, 4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randint(0, 5, (4, 4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 5, 4])"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "o.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 4])"
+ ]
+ },
+ "execution_count": 54,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[0, 1, 3, 2],\n",
+ " [1, 4, 4, 4],\n",
+ " [1, 4, 2, 1],\n",
+ " [2, 0, 4, 4]])"
+ ]
+ },
+ "execution_count": 55,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ 0.0647, -1.3831, 0.0266, 0.8528],\n",
+ " [ 1.4976, 0.4153, 1.0353, 0.0154],\n",
+ " [ 1.4562, -0.3568, 0.3599, -0.6222],\n",
+ " [ 0.2773, 0.4563, 0.9282, -2.1445],\n",
+ " [ 0.5191, 0.3683, -0.3469, 0.1355]],\n",
+ "\n",
+ " [[ 0.0424, -0.3215, 0.5662, -0.4217],\n",
+ " [ 2.0793, 1.2817, 0.1559, -0.6900],\n",
+ " [-1.1751, -0.3359, 1.7875, -0.3671],\n",
+ " [-0.4553, -0.3952, -0.8633, 0.1538],\n",
+ " [-1.3862, 0.4255, -2.2948, 0.0312]],\n",
+ "\n",
+ " [[-1.4257, 2.2662, 0.2670, -0.4330],\n",
+ " [-0.3244, -0.8669, -0.2571, 0.8028],\n",
+ " [ 0.9109, -0.2289, -1.2095, -0.9761],\n",
+ " [-0.0156, 1.2403, -1.1967, 0.6841],\n",
+ " [-0.8185, 0.2967, -2.1639, -0.7903]],\n",
+ "\n",
+ " [[-1.0425, 0.1426, 0.1383, 0.9784],\n",
+ " [-1.2853, 1.4123, -0.2272, -0.3335],\n",
+ " [ 1.5751, -0.7663, 0.9610, 0.5686],\n",
+ " [ 0.9697, -1.5515, -0.8658, -0.5882],\n",
+ " [-1.2467, 0.0539, 0.1208, -1.0297]]])"
+ ]
+ },
+ "execution_count": 56,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "o"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.8355)"
+ ]
+ },
+ "execution_count": 57,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "loss(o, t)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "unsupported operand type(s) for |: 'int' and 'Tensor'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/tmp/ipykernel_9275/1867668791.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for |: 'int' and 'Tensor'"
+ ]
+ }
+ ],
+ "source": [
+ "t[:, 2] == 2 | t[:, 2] == 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 1])"
+ ]
+ },
+ "execution_count": 60,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.argmax(o, dim=-1)[:, -1:].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class LabelSmoothingLossCanonical(nn.Module):\n",
+ " def __init__(self, smoothing=0.0, dim=-1):\n",
+ " super(LabelSmoothingLossCanonical, self).__init__()\n",
+ " self.confidence = 1.0 - smoothing\n",
+ " self.smoothing = smoothing\n",
+ " self.dim = dim\n",
+ "\n",
+ " def forward(self, pred, target):\n",
+ " pred = pred.log_softmax(dim=self.dim)\n",
+ " with torch.no_grad():\n",
+ " # true_dist = pred.data.clone()\n",
+ " true_dist = torch.zeros_like(pred)\n",
+ " print(true_dist.shape)\n",
+ " true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n",
+ " print(true_dist.shape)\n",
+ " print(true_dist)\n",
+ " true_dist.masked_fill_((target == 4).unsqueeze(1), 0)\n",
+ " print(true_dist)\n",
+ " true_dist += self.smoothing / pred.size(self.dim)\n",
+ " return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "l = LabelSmoothingLossCanonical(0.1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 5, 4])\n",
+ "torch.Size([1, 5, 4])\n",
+ "tensor([[[0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.9000, 0.9000, 0.0000, 0.9000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.9000, 0.0000]]])\n",
+ "tensor([[[0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.9000, 0.9000, 0.0000, 0.9000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000]]])\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.9438)"
+ ]
+ },
+ "execution_count": 45,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "l(o, t)"
]
},
{
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index 76ca6b1..ed67e9c 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -2,24 +2,10 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "6ce2519f",
"metadata": {},
- "outputs": [
- {
- "ename": "ModuleNotFoundError",
- "evalue": "No module named 'loguru.logger'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m/tmp/ipykernel_3883/2979229631.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_synthetic_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMSyntheticParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_extended_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMExtendedParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/data/iam_paragraphs.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0memnist\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0memnist_mapping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAM\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmappings\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWordPieceMapping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWordPiece\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/projects/text-recognizer/text_recognizer/data/mappings.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mattr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mloguru\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogger\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mlog\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'loguru.logger'"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICE'] = ''\n",
@@ -62,42 +48,12 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "c6188bce",
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2021-07-30 23:09:28.009 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:90 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-07-30 23:09:28.117 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:90 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-07-30 23:09:28.277 | INFO | text_recognizer.data.iam_paragraphs:setup:103 - Loading IAM paragraph regions and lines for None...\n",
- "2021-07-30 23:09:47.357 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:90 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-07-30 23:09:50.514 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:90 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-07-30 23:09:50.612 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:67 - IAM Synthetic dataset steup for stage None...\n",
- "2021-07-30 23:10:02.137 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:90 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "IAM Original and Synthetic Paragraphs Dataset\n",
- "Num classes: 1006\n",
- "Dims: (1, 576, 640)\n",
- "Output dims: (682, 1)\n",
- "Train/val/test sizes: 19959, 262, 231\n",
- "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0026), tensor(0.0239), tensor(0.7412))\n",
- "Train Batch y stats: (torch.Size([1, 451]), torch.int64, tensor(1), tensor(1002))\n",
- "Test Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0372), tensor(0.0767), tensor(0.8118))\n",
- "Test Batch y stats: (torch.Size([1, 451]), torch.int64, tensor(1), tensor(1003))\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"dataset = IAMExtendedParagraphs(batch_size=1, word_pieces=True)\n",
"dataset.prepare_data()\n",
@@ -107,21 +63,10 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "55b26b5d",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "1006"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"len(dataset.mapping)"
]
@@ -161,7 +106,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"id": "0cf22683",
"metadata": {},
"outputs": [],
@@ -171,146 +116,52 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "8541e6ee",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 576, 640])"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"x.shape"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "40447ce6",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([1002, 59, 6, 1, 54, 7, 2, 41, 36, 15, 4, 3,\n",
- " 842, 2, 46, 230, 65, 439, 97, 784, 779, 7, 1003, 1,\n",
- " 218, 18, 12, 11, 1, 20, 26, 54, 23, 36, 4, 1,\n",
- " 511, 679, 352, 324, 4, 43, 172, 33, 14, 81, 84, 1,\n",
- " 47, 281, 59, 1003, 890, 350, 14, 49, 33, 14, 81, 84,\n",
- " 1, 20, 15, 95, 23, 21, 2, 24, 21, 59, 1, 2,\n",
- " 7, 31, 54, 7, 15, 20, 54, 13, 33, 3, 1003, 784,\n",
- " 68, 409, 196, 663, 2, 42, 1, 9, 41, 31, 89, 14,\n",
- " 1003, 827, 89, 35, 1, 54, 7, 15, 23, 54, 7, 16,\n",
- " 7, 21, 15, 4, 14, 42, 1, 24, 31, 247, 26, 89,\n",
- " 28, 1003, 1, 31, 7, 21, 15, 54, 7, 2, 33, 3,\n",
- " 867, 166, 2, 96, 15, 2, 10, 928, 2, 88, 16, 1003,\n",
- " 3, 842, 2, 46, 230, 115, 52, 26, 52, 89, 53, 105,\n",
- " 170, 1, 9, 41, 31, 89, 1, 17, 7, 26, 20, 54,\n",
- " 15, 16, 7, 21, 15, 201, 1003, 3, 252, 176, 44, 1,\n",
- " 9, 41, 31, 89, 28, 1, 20, 2, 2, 24, 31, 23,\n",
- " 20, 15, 23, 24, 21, 201, 3, 108, 23, 216, 2, 62,\n",
- " 13, 1003, 608, 30, 16, 105, 28, 1, 9, 41, 31, 89,\n",
- " 663, 14, 82, 26, 58, 15, 97, 2, 1003, 10, 1, 26,\n",
- " 2, 13, 31, 47, 24, 36, 24, 46, 13, 4, 1, 9,\n",
- " 41, 31, 89, 14, 87, 664, 1, 2, 31, 23, 7, 21,\n",
- " 31, 7, 201, 1, 33, 33, 33, 33, 1000, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,\n",
- " 1001, 1001, 1001, 1001, 1001, 1001, 1001])"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"y"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"id": "016e8c81",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "451"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"len(y)"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"id": "7aa8c021",
"metadata": {
"scrolled": true
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 576, 640])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"x.shape"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"id": "7ef93252",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 864x864 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
+ "outputs": [],
"source": [
"_plot(x[0], vmax=1, title=dataset.mapping.get_text(y))"
]
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index b652bdd..e3e92e2 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -2,10 +2,19 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 4,
"id": "1e40a88b",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -25,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 5,
"id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0",
"metadata": {},
"outputs": [],
@@ -45,32 +54,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "mapping:\n",
- " _target_: text_recognizer.data.mappings.WordPieceMapping\n",
- " num_features: 1000\n",
- " tokens: iamdb_1kwp_tokens_1000.txt\n",
- " lexicon: iamdb_1kwp_lex_1000.txt\n",
- " data_dir: null\n",
- " use_words: false\n",
- " prepend_wordsep: false\n",
- " special_tokens:\n",
- " - <s>\n",
- " - <e>\n",
- " - <p>\n",
- " extra_symbols:\n",
- " - \\n\n",
"_target_: text_recognizer.models.transformer.TransformerLitModel\n",
"interval: step\n",
"monitor: val/loss\n",
- "ignore_tokens:\n",
- "- <s>\n",
- "- <e>\n",
- "- <p>\n",
"start_token: <s>\n",
"end_token: <e>\n",
"pad_token: <p>\n",
"\n",
- "{'mapping': {'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\\\n']}, '_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'ignore_tokens': ['<s>', '<e>', '<p>'], 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n"
+ "{'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n"
]
}
],
@@ -85,6 +76,20 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "5e6b49ce-7685-4491-bd0a-51487f06a237",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# context initialization\n",
+ "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"word_piece\")\n",
+ " print(OmegaConf.to_yaml(cfg))\n",
+ " print(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"id": "9c797159-845e-42c6-bd65-1c976ad627cd",
"metadata": {},
"outputs": [],
@@ -98,6 +103,405 @@
},
{
"cell_type": "code",
+ "execution_count": 6,
+ "id": "764c8736-7d68-4261-a57d-face10ebbf42",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "callbacks:\n",
+ " model_checkpoint:\n",
+ " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n",
+ " monitor: val/loss\n",
+ " save_top_k: 1\n",
+ " save_last: true\n",
+ " mode: min\n",
+ " verbose: false\n",
+ " dirpath: checkpoints/\n",
+ " filename:\n",
+ " epoch:02d: null\n",
+ " learning_rate_monitor:\n",
+ " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n",
+ " logging_interval: step\n",
+ " log_momentum: false\n",
+ " watch_model:\n",
+ " _target_: callbacks.wandb_callbacks.WatchModel\n",
+ " log: all\n",
+ " log_freq: 100\n",
+ " upload_code_as_artifact:\n",
+ " _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact\n",
+ " project_dir: ${work_dir}/text_recognizer\n",
+ " upload_ckpts_as_artifact:\n",
+ " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
+ " ckpt_dir: checkpoints/\n",
+ " upload_best_only: true\n",
+ " log_text_predictions:\n",
+ " _target_: callbacks.wandb_callbacks.LogTextPredictions\n",
+ " num_samples: 8\n",
+ "criterion:\n",
+ " _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss\n",
+ " smoothing: 0.1\n",
+ " ignore_index: 1002\n",
+ "datamodule:\n",
+ " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n",
+ " batch_size: 8\n",
+ " num_workers: 12\n",
+ " train_fraction: 0.8\n",
+ " augment: true\n",
+ " pin_memory: false\n",
+ "logger:\n",
+ " wandb:\n",
+ " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n",
+ " project: text-recognizer\n",
+ " name: null\n",
+ " save_dir: .\n",
+ " offline: false\n",
+ " id: null\n",
+ " log_model: false\n",
+ " prefix: ''\n",
+ " job_type: train\n",
+ " group: ''\n",
+ " tags: []\n",
+ "lr_scheduler:\n",
+ " _target_: torch.optim.lr_scheduler.OneCycleLR\n",
+ " max_lr: 0.001\n",
+ " total_steps: null\n",
+ " epochs: 512\n",
+ " steps_per_epoch: 4992\n",
+ " pct_start: 0.3\n",
+ " anneal_strategy: cos\n",
+ " cycle_momentum: true\n",
+ " base_momentum: 0.85\n",
+ " max_momentum: 0.95\n",
+ " div_factor: 25.0\n",
+ " final_div_factor: 10000.0\n",
+ " three_phase: true\n",
+ " last_epoch: -1\n",
+ " verbose: false\n",
+ "mapping:\n",
+ " _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping\n",
+ " num_features: 1000\n",
+ " tokens: iamdb_1kwp_tokens_1000.txt\n",
+ " lexicon: iamdb_1kwp_lex_1000.txt\n",
+ " data_dir: null\n",
+ " use_words: false\n",
+ " prepend_wordsep: false\n",
+ " special_tokens:\n",
+ " - <s>\n",
+ " - <e>\n",
+ " - <p>\n",
+ " extra_symbols:\n",
+ " - '\n",
+ "\n",
+ " '\n",
+ "model:\n",
+ " _target_: text_recognizer.models.transformer.TransformerLitModel\n",
+ " interval: step\n",
+ " monitor: val/loss\n",
+ " max_output_len: 451\n",
+ " start_token: <s>\n",
+ " end_token: <e>\n",
+ " pad_token: <p>\n",
+ "network:\n",
+ " encoder:\n",
+ " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n",
+ " arch: b0\n",
+ " out_channels: 1280\n",
+ " stochastic_dropout_rate: 0.2\n",
+ " bn_momentum: 0.99\n",
+ " bn_eps: 0.001\n",
+ " decoder:\n",
+ " _target_: text_recognizer.networks.transformer.Decoder\n",
+ " dim: 96\n",
+ " depth: 2\n",
+ " num_heads: 8\n",
+ " attn_fn: text_recognizer.networks.transformer.attention.Attention\n",
+ " attn_kwargs:\n",
+ " dim_head: 16\n",
+ " dropout_rate: 0.2\n",
+ " norm_fn: torch.nn.LayerNorm\n",
+ " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n",
+ " ff_kwargs:\n",
+ " dim_out: null\n",
+ " expansion_factor: 4\n",
+ " glu: true\n",
+ " dropout_rate: 0.2\n",
+ " cross_attend: true\n",
+ " pre_norm: true\n",
+ " rotary_emb: null\n",
+ " _target_: text_recognizer.networks.conv_transformer.ConvTransformer\n",
+ " input_dims:\n",
+ " - 1\n",
+ " - 576\n",
+ " - 640\n",
+ " hidden_dim: 96\n",
+ " dropout_rate: 0.2\n",
+ " num_classes: 1006\n",
+ " pad_index: 1002\n",
+ "optimizer:\n",
+ " _target_: madgrad.MADGRAD\n",
+ " lr: 0.001\n",
+ " momentum: 0.9\n",
+ " weight_decay: 0\n",
+ " eps: 1.0e-06\n",
+ "trainer:\n",
+ " _target_: pytorch_lightning.Trainer\n",
+ " stochastic_weight_avg: false\n",
+ " auto_scale_batch_size: binsearch\n",
+ " auto_lr_find: false\n",
+ " gradient_clip_val: 0\n",
+ " fast_dev_run: false\n",
+ " gpus: 1\n",
+ " precision: 16\n",
+ " max_epochs: 512\n",
+ " terminate_on_nan: true\n",
+ " weights_summary: top\n",
+ " limit_train_batches: 1.0\n",
+ " limit_val_batches: 1.0\n",
+ " limit_test_batches: 1.0\n",
+ " resume_from_checkpoint: null\n",
+ "seed: 4711\n",
+ "tune: false\n",
+ "train: true\n",
+ "test: true\n",
+ "logging: INFO\n",
+ "debug: false\n",
+ "\n",
+ "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}, 'criterion': {'_target_': 'text_recognizer.criterions.label_smoothing.LabelSmoothingLoss', 'smoothing': 0.1, 'ignore_index': 1002}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 8, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 512, 'steps_per_epoch': 4992, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 
'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'max_output_len': 451, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}, 'network': {'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 96, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'dim_head': 16, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'cross_attend': True, 'pre_norm': True, 'rotary_emb': None}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 576, 640], 'hidden_dim': 96, 'dropout_rate': 0.2, 'num_classes': 1006, 'pad_index': 1002}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 512, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'debug': False}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# context initialization\n",
+ "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"config\")\n",
+ " print(OmegaConf.to_yaml(cfg))\n",
+ " print(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "9382f0ab-8760-4d59-b0b5-b8b65dd1ea31",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cfg.get(\"callbacks\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "216d5680-66bf-4190-9401-1a59dbbc43af",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pytorch_lightning.callbacks.ModelCheckpoint\n",
+ "pytorch_lightning.callbacks.LearningRateMonitor\n",
+ "callbacks.wandb_callbacks.WatchModel\n",
+ "callbacks.wandb_callbacks.UploadCodeAsArtifact\n",
+ "callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
+ "callbacks.wandb_callbacks.LogTextPredictions\n"
+ ]
+ }
+ ],
+ "source": [
+ "for l in cfg.callbacks.values():\n",
+ " print(l.get(\"_target_\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "c1a9aa6b-6405-4ffe-b065-02340762476a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-08-03 15:27:02.069 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
+ ]
+ }
+ ],
+ "source": [
+ "mapping = instantiate(cfg.mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "network = instantiate(cfg.network)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a23893a9-a0da-4327-a617-dc0c2011e5e8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "OmegaConf.set_struct(cfg, False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "a6fae1fa-492d-4648-80fd-1c0dac659b02",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "datamodule = instantiate(cfg.datamodule, mapping=mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "514053ef-fcac-4f3c-a7c8-72c6927d6798",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-08-03 15:28:22.541 | INFO | text_recognizer.data.iam_paragraphs:setup:95 - Loading IAM paragraph regions and lines for None...\n",
+ "2021-08-03 15:28:45.280 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:68 - IAM Synthetic dataset steup for stage None...\n"
+ ]
+ }
+ ],
+ "source": [
+ "datamodule.prepare_data()\n",
+ "datamodule.setup()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "4bad950b-a197-4c60-ad89-903124659a98",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "4992"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(datamodule.train_dataloader())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7db05cbd-48b3-43fa-a99a-353126311879",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mapping"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "f6e01c15-9a1b-4036-87ae-78716c592264",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = cfg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "4dc475fc-31f4-487e-88c8-b0f445131f5b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_fn = instantiate(cfg.criterion)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "c5c8ed64-d98c-47b5-baf2-1ba57a6c882f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import hydra"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "b5ff5b24-f804-402b-a8ab-f366443025ca",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ " model = hydra.utils.instantiate(\n",
+ " config.model,\n",
+ " mapping=mapping,\n",
+ " network=network,\n",
+ " loss_fn=loss_fn,\n",
+ " optimizer_config=config.optimizer,\n",
+ " lr_scheduler_config=config.lr_scheduler,\n",
+ " _recursive_=False,\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "99f8a39f-8b10-4f7d-8bff-52794fd48717",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<bound method WordPieceMapping.get_index of <text_recognizer.data.word_piece_mapping.WordPieceMapping object at 0x7fae3b489610>>"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mapping.get_index"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"id": "af2c8cfa-0b45-4681-b671-0f97ace62516",
"metadata": {},
diff --git a/poetry.lock b/poetry.lock
index f8a6de3..76ea763 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -244,7 +244,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "charset-normalizer"
-version = "2.0.3"
+version = "2.0.4"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
category = "main"
optional = false
@@ -658,7 +658,7 @@ test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "nose", "ipyparallel"]
[[package]]
name = "ipython"
-version = "7.25.0"
+version = "7.26.0"
description = "IPython: Productive Interactive Computing"
category = "dev"
optional = false
@@ -814,7 +814,7 @@ traitlets = "*"
[[package]]
name = "jupyter-server"
-version = "1.10.1"
+version = "1.10.2"
description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications."
category = "dev"
optional = false
@@ -876,7 +876,7 @@ pygments = ">=2.4.1,<3"
[[package]]
name = "jupyterlab-server"
-version = "2.6.1"
+version = "2.6.2"
description = "A set of server components for JupyterLab and JupyterLab like applications ."
category = "dev"
optional = false
@@ -1542,7 +1542,7 @@ six = ">=1.5"
[[package]]
name = "pytorch-lightning"
-version = "1.4.0"
+version = "1.4.1"
description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
category = "main"
optional = false
@@ -1936,7 +1936,7 @@ python-versions = ">= 3.5"
[[package]]
name = "tqdm"
-version = "4.61.2"
+version = "4.62.0"
description = "Fast, Extensible Progress Meter"
category = "main"
optional = false
@@ -1994,7 +1994,7 @@ python-versions = "*"
[[package]]
name = "urllib3"
-version = "1.25.11"
+version = "1.26.6"
description = "HTTP library with thread-safe connection pooling, file post, and more."
category = "main"
optional = false
@@ -2007,11 +2007,11 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
name = "wandb"
-version = "0.10.33"
+version = "0.11.2"
description = "A CLI and library for interacting with the Weights and Biases API."
category = "dev"
optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+python-versions = ">=3.5"
[package.dependencies]
Click = ">=7.0,<8.0.0 || >8.0.0"
@@ -2025,11 +2025,11 @@ psutil = ">=5.0.0"
python-dateutil = ">=2.6.1"
PyYAML = "*"
requests = ">=2.0.0,<3"
-sentry-sdk = ">=0.4.0"
+sentry-sdk = ">=1.0.0"
shortuuid = ">=0.5.0"
six = ">=1.13.0"
subprocess32 = ">=3.5.3"
-urllib3 = {version = "<=1.25.11", markers = "sys_platform == \"win32\" or sys_platform == \"cygwin\""}
+urllib3 = ">=1.26.5"
[package.extras]
aws = ["boto3"]
@@ -2037,7 +2037,7 @@ gcp = ["google-cloud-storage"]
grpc = ["grpcio (==1.27.2)"]
kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"]
media = ["numpy", "moviepy", "pillow", "bokeh", "soundfile", "plotly"]
-sweeps = ["numpy"]
+sweeps = ["numpy (>=1.15,<1.21)", "scipy (>=1.5.4)", "pyyaml", "scikit-learn (==0.24.1)", "jsonschema (>=3.2.0)", "jsonref (>=0.2)", "pydantic (>=1.8.2)"]
[[package]]
name = "wcwidth"
@@ -2111,7 +2111,7 @@ multidict = ">=4.0"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "91db4ec12db098a730fcdd63a7590cab62f0be1072c65229ae52fc35c58875a7"
+content-hash = "78ee5b3911c60380b6d7a487f61df807d2d623cdb8f9848dee260b581ec06460"
[metadata.files]
absl-py = [
@@ -2300,8 +2300,8 @@ chardet = [
{file = "chardet-4.0.0.tar.gz", hash = "sha256:0d6f53a15db4120f2b08c94f11e7d93d2c911ee118b6b30a04ec3ee8310179fa"},
]
charset-normalizer = [
- {file = "charset-normalizer-2.0.3.tar.gz", hash = "sha256:c46c3ace2d744cfbdebceaa3c19ae691f53ae621b39fd7570f59d14fb7f2fd12"},
- {file = "charset_normalizer-2.0.3-py3-none-any.whl", hash = "sha256:88fce3fa5b1a84fdcb3f603d889f723d1dd89b26059d0123ca435570e848d5e1"},
+ {file = "charset-normalizer-2.0.4.tar.gz", hash = "sha256:f23667ebe1084be45f6ae0538e4a5a865206544097e4e8bbcacf42cd02a348f3"},
+ {file = "charset_normalizer-2.0.4-py3-none-any.whl", hash = "sha256:0c8911edd15d19223366a194a513099a302055a962bca2cec0f54b8b63175d8b"},
]
click = [
{file = "click-7.1.2-py2.py3-none-any.whl", hash = "sha256:dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc"},
@@ -2626,8 +2626,8 @@ ipykernel = [
{file = "ipykernel-6.0.3.tar.gz", hash = "sha256:0df34a78c7e1422800d6078cde65ccdcdb859597046c338c759db4dbc535c58f"},
]
ipython = [
- {file = "ipython-7.25.0-py3-none-any.whl", hash = "sha256:aa21412f2b04ad1a652e30564fff6b4de04726ce875eab222c8430edc6db383a"},
- {file = "ipython-7.25.0.tar.gz", hash = "sha256:54bbd1fe3882457aaf28ae060a5ccdef97f212a741754e420028d4ec5c2291dc"},
+ {file = "ipython-7.26.0-py3-none-any.whl", hash = "sha256:892743b65c21ed72b806a3a602cca408520b3200b89d1924f4b3d2cdb3692362"},
+ {file = "ipython-7.26.0.tar.gz", hash = "sha256:0cff04bb042800129348701f7bd68a430a844e8fb193979c08f6c99f28bb735e"},
]
ipython-genutils = [
{file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
@@ -2666,8 +2666,8 @@ jupyter-core = [
{file = "jupyter_core-4.7.1.tar.gz", hash = "sha256:79025cb3225efcd36847d0840f3fc672c0abd7afd0de83ba8a1d3837619122b4"},
]
jupyter-server = [
- {file = "jupyter_server-1.10.1-py3-none-any.whl", hash = "sha256:b3eef770ffa34595ed26a6e4460866eaf0f4ff710eccc7648f701bb8c1d0443c"},
- {file = "jupyter_server-1.10.1.tar.gz", hash = "sha256:fe6b589bd8d8fe08f608e90ce7da1e6bbfd020d99897453b45149a7853e9188f"},
+ {file = "jupyter_server-1.10.2-py3-none-any.whl", hash = "sha256:491c920013144a2d6f5286ab4038df6a081b32352c9c8b928ec8af17eb2a5e10"},
+ {file = "jupyter_server-1.10.2.tar.gz", hash = "sha256:d3a3b68ebc6d7bfee1097f1712cf7709ee39c92379da2cc08724515bb85e72bf"},
]
jupyterlab = [
{file = "jupyterlab-3.1.1-py3-none-any.whl", hash = "sha256:a181184b1000a550c38da35471dcf91ce11e96750de56430be3fc93ca01dde1e"},
@@ -2678,8 +2678,8 @@ jupyterlab-pygments = [
{file = "jupyterlab_pygments-0.1.2.tar.gz", hash = "sha256:cfcda0873626150932f438eccf0f8bf22bfa92345b814890ab360d666b254146"},
]
jupyterlab-server = [
- {file = "jupyterlab_server-2.6.1-py3-none-any.whl", hash = "sha256:58d4b660fce8da4e90f0433ac54f462436fe5fbe731e3a281e15adcdecddb0eb"},
- {file = "jupyterlab_server-2.6.1.tar.gz", hash = "sha256:73279d1ffdcd3426f716bf5538cf1fdd2eb8a340ac25c5688f3c192c5bd3afc9"},
+ {file = "jupyterlab_server-2.6.2-py3-none-any.whl", hash = "sha256:ab568da1dcef2ffdfc9161128dc00b931aae94d6a94978b16f55330dcd1cb043"},
+ {file = "jupyterlab_server-2.6.2.tar.gz", hash = "sha256:6dc6e7d26600d110b862acbfaa4d1a2c5e86781008d139213896d96178c3accd"},
]
jupyterlab-widgets = [
{file = "jupyterlab_widgets-1.0.0-py3-none-any.whl", hash = "sha256:caeaf3e6103180e654e7d8d2b81b7d645e59e432487c1d35a41d6d3ee56b3fef"},
@@ -3202,8 +3202,8 @@ python-dateutil = [
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
]
pytorch-lightning = [
- {file = "pytorch-lightning-1.4.0.tar.gz", hash = "sha256:6529cf064f9dc323c94f3ce84b56ee1a05db1b0ab17db77c4d15aa36e34da81f"},
- {file = "pytorch_lightning-1.4.0-py3-none-any.whl", hash = "sha256:41fb26e649b830019ecdffb6dc6558266e1317963f7bf2cddb1f1ed862245928"},
+ {file = "pytorch-lightning-1.4.1.tar.gz", hash = "sha256:1d1128aeb5d0e523d2204c4d9399d65c4e5f41ff0370e96d694a823af5e8e6f3"},
+ {file = "pytorch_lightning-1.4.1-py3-none-any.whl", hash = "sha256:4a06723a66296a2ac94cdf353335d64e7ae76c37202b2a4c38a845063e3fe386"},
]
pytz = [
{file = "pytz-2021.1-py2.py3-none-any.whl", hash = "sha256:eb10ce3e7736052ed3623d49975ce333bcd712c7bb19a58b9e2089d4057d0798"},
@@ -3569,8 +3569,8 @@ tornado = [
{file = "tornado-6.1.tar.gz", hash = "sha256:33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791"},
]
tqdm = [
- {file = "tqdm-4.61.2-py2.py3-none-any.whl", hash = "sha256:5aa445ea0ad8b16d82b15ab342de6b195a722d75fc1ef9934a46bba6feafbc64"},
- {file = "tqdm-4.61.2.tar.gz", hash = "sha256:8bb94db0d4468fea27d004a0f1d1c02da3cdedc00fe491c0de986b76a04d6b0a"},
+ {file = "tqdm-4.62.0-py2.py3-none-any.whl", hash = "sha256:706dea48ee05ba16e936ee91cb3791cd2ea6da348a0e50b46863ff4363ff4340"},
+ {file = "tqdm-4.62.0.tar.gz", hash = "sha256:3642d483b558eec80d3c831e23953582c34d7e4540db86d9e5ed9dad238dabc6"},
]
traitlets = [
{file = "traitlets-5.0.5-py3-none-any.whl", hash = "sha256:69ff3f9d5351f31a7ad80443c2674b7099df13cc41fc5fa6e2f6d3b0330b0426"},
@@ -3618,12 +3618,12 @@ typing-extensions = [
{file = "typing_extensions-3.10.0.0.tar.gz", hash = "sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342"},
]
urllib3 = [
- {file = "urllib3-1.25.11-py2.py3-none-any.whl", hash = "sha256:f5321fbe4bf3fefa0efd0bfe7fb14e90909eb62a48ccda331726b4319897dd5e"},
- {file = "urllib3-1.25.11.tar.gz", hash = "sha256:8d7eaa5a82a1cac232164990f04874c594c9453ec55eef02eab885aa02fc17a2"},
+ {file = "urllib3-1.26.6-py2.py3-none-any.whl", hash = "sha256:39fb8672126159acb139a7718dd10806104dec1e2f0f6c88aab05d17df10c8d4"},
+ {file = "urllib3-1.26.6.tar.gz", hash = "sha256:f57b4c16c62fa2760b7e3d97c35b255512fb6b59a259730f36ba32ce9f8e342f"},
]
wandb = [
- {file = "wandb-0.10.33-py2.py3-none-any.whl", hash = "sha256:84f111e31cc4d6e95dcb62028c0c2a9fed7cdf0f8c563d86438aeadcf6d5f495"},
- {file = "wandb-0.10.33.tar.gz", hash = "sha256:ee69d4e251ae55e73d7d8b1a88b5629a588c820cce8dc8d5f5da15ac298556a7"},
+ {file = "wandb-0.11.2-py2.py3-none-any.whl", hash = "sha256:7bd00153873b0c1ceb31ae45852991bb08c1785f9c89d30dec0c569378ea3020"},
+ {file = "wandb-0.11.2.tar.gz", hash = "sha256:324ee38bcc1baea13cf914d5b28b21519237e17ab13dc7cac0870e0291930a2e"},
]
wcwidth = [
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
diff --git a/pyproject.toml b/pyproject.toml
index 6c5a2a0..7d81365 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,8 +15,6 @@ click = "^7.1.2"
boltons = "^20.1.0"
h5py = "^3.2.1"
toml = "^0.10.1"
-torch = "^1.9.0"
-torchvision = "^0.10.0"
loguru = "^0.5.0"
matplotlib = "^3.2.1"
tqdm = "^4.46.1"
@@ -24,7 +22,7 @@ opencv-python = "^4.3.0"
nltk = "^3.5"
torch-summary = "^1.4.2"
defusedxml = "^0.6.0"
-omegaconf = "^2.0.2"
+omegaconf = "^2.1.0"
einops = "^0.3.0"
gtn = "^0.0.0"
sentencepiece = "^0.1.95"
@@ -33,8 +31,10 @@ Pillow = "^8.1.2"
madgrad = "^1.0"
editdistance = "^0.5.3"
torchmetrics = "^0.4.1"
-hydra-core = "^1.0.6"
+hydra-core = "^1.1.0"
attr = "^0.3.1"
+torch = "^1.9.0"
+torchvision = "^0.10.0"
[tool.poetry.dev-dependencies]
pytest = "^5.4.2"
@@ -50,7 +50,7 @@ flake8-import-order = "^0.18.1"
safety = "^1.9.0"
mypy = "^0.770"
typeguard = "^2.7.1"
-wandb = "^0.10.30"
+wandb = "^0.11.2"
scipy = "^1.6.1"
flake8-annotations = "^2.6.2"
flake8-docstrings = "^1.6.0"
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py
index 40a7609..cc71c45 100644
--- a/text_recognizer/criterions/label_smoothing.py
+++ b/text_recognizer/criterions/label_smoothing.py
@@ -6,37 +6,31 @@ import torch.nn.functional as F
class LabelSmoothingLoss(nn.Module):
- """Label smoothing cross entropy loss."""
-
- def __init__(
- self, label_smoothing: float, vocab_size: int, ignore_index: int = -100
- ) -> None:
- assert 0.0 < label_smoothing <= 1.0
- self.ignore_index = ignore_index
+ def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1):
super().__init__()
+ assert 0.0 <= smoothing <= 1.0  # inclusive lower bound: the declared default (0.0) must be constructible
+ self.ignore_index = ignore_index
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.dim = dim
- smoothing_value = label_smoothing / (vocab_size - 2)
- one_hot = torch.full((vocab_size,), smoothing_value)
- one_hot[self.ignore_index] = 0
- self.register_buffer("one_hot", one_hot.unsqueeze(0))
-
- self.confidence = 1.0 - label_smoothing
-
- def forward(self, output: Tensor, targets: Tensor) -> Tensor:
+ def forward(self, output: Tensor, target: Tensor) -> Tensor:
"""Computes the loss.
Args:
- output (Tensor): Predictions from the network.
+ output (Tensor): Raw (pre-softmax) predictions from the network.
targets (Tensor): Ground truth.
Shapes:
- outpus: Batch size x num classes
- targets: Batch size
+ TBC
Returns:
Tensor: Label smoothing loss.
"""
- model_prob = self.one_hot.repeat(targets.size(0), 1)
- model_prob.scatter_(1, targets.unsqueeze(1), self.confidence)
- model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0)
- return F.kl_div(output, model_prob, reduction="sum")
+ output = output.log_softmax(dim=self.dim)
+ with torch.no_grad():
+ true_dist = torch.zeros_like(output)
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+ true_dist.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)  # was hard-coded 4; NOTE(review): masked rows still receive the smoothing mass added below - confirm intended
+ true_dist += self.smoothing / output.size(self.dim)
+ return torch.mean(torch.sum(-true_dist * output, dim=self.dim))
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index fd914b6..16a06d9 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,12 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Type
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
-from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -25,7 +25,7 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
- mapping: AbstractMapping = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
pin_memory: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/base_mapping.py b/text_recognizer/data/base_mapping.py
new file mode 100644
index 0000000..572ac95
--- /dev/null
+++ b/text_recognizer/data/base_mapping.py
@@ -0,0 +1,37 @@
+"""Abstract base class for mappings between tokens and indices."""
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+from torch import Tensor
+
+
+class AbstractMapping(ABC):  # common interface for token <-> index mappings; concrete subclasses supply the four lookups below
+ def __init__(
+ self, input_size: List[int], mapping: List[str], inverse_mapping: Dict[str, int]
+ ) -> None:
+ self.input_size = input_size
+ self.mapping = mapping  # list of tokens; its length is what __len__ / num_classes report
+ self.inverse_mapping = inverse_mapping  # token string -> integer index
+
+ def __len__(self) -> int:  # number of entries in the mapping
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:  # alias for len(self)
+ return self.__len__()
+
+ @abstractmethod
+ def get_token(self, *args, **kwargs) -> str:  # subclass hook: index -> token
+ ...
+
+ @abstractmethod
+ def get_index(self, *args, **kwargs) -> Tensor:  # subclass hook: token -> index tensor
+ ...
+
+ @abstractmethod
+ def get_text(self, *args, **kwargs) -> str:  # subclass hook: sequence of indices -> text
+ ...
+
+ @abstractmethod
+ def get_indices(self, *args, **kwargs) -> Tensor:  # subclass hook: text -> tensor of indices
+ ...
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
index 8938830..a5a5360 100644
--- a/text_recognizer/data/download_utils.py
+++ b/text_recognizer/data/download_utils.py
@@ -1,7 +1,7 @@
"""Util functions for downloading datasets."""
import hashlib
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, Optional
from urllib.request import urlretrieve
from loguru import logger as log
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
new file mode 100644
index 0000000..6c4c43b
--- /dev/null
+++ b/text_recognizer/data/emnist_mapping.py
@@ -0,0 +1,37 @@
+"""Emnist mapping."""
+from typing import List, Optional, Union, Set
+
+from torch import Tensor
+
+from text_recognizer.data.base_mapping import AbstractMapping
+from text_recognizer.data.emnist import emnist_mapping
+
+
+class EmnistMapping(AbstractMapping):  # character-level mapping built from emnist_mapping(), optionally extended with extra symbols
+ def __init__(self, extra_symbols: Optional[Set[str]] = None) -> None:
+ self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
+ self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
+ self.extra_symbols
+ )
+ super().__init__(self.input_size, self.mapping, self.inverse_mapping)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ if 0 <= (index := int(index)) < len(self.mapping):  # bounds check; the old "index in self.mapping" compared an int against the token strings and never matched
+ return self.mapping[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ if token in self.inverse_mapping:
+ return Tensor([self.inverse_mapping[token]])  # list-wrapped: Tensor(n) allocates n uninitialised values instead of holding the value n
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return "".join([self.mapping[index] for index in indices])
+
+ def get_indices(self, text: str) -> Tensor:
+ return Tensor([self.inverse_mapping[token] for token in text])
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index ccf0759..df0c0e1 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,6 +1,4 @@
"""IAM original and sythetic dataset class."""
-from typing import Dict, List
-
import attr
from torch.utils.data import ConcatDataset
@@ -15,7 +13,6 @@ class IAMExtendedParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
- num_classes: int = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 1c63729..aba38f9 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -22,7 +22,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 6189f7d..11f899f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -17,7 +17,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.transforms import WordPiece
@@ -50,11 +50,9 @@ class IAMParagraphs(BaseDataModule):
if PROCESSED_DATA_DIRNAME.exists():
return
- log.info(
- "Cropping IAM paragraph regions and saving them along with labels..."
- )
+ log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
iam.prepare_data()
properties = {}
@@ -83,7 +81,9 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
+ strings=labels,
+ mapping=self.mapping.inverse_mapping,
+ length=self.output_dims[0],
)
return BaseDataset(
data,
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index c938f8b..24ca896 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -21,7 +21,7 @@ from text_recognizer.data.iam_paragraphs import (
IMAGE_SCALE_FACTOR,
resize_image,
)
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -47,7 +47,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
log.info("Preparing IAM lines for synthetic paragraphs dataset.")
log.info("Cropping IAM line regions and loading labels.")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index 40fbee4..8e53815 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -13,8 +13,6 @@ import click
from loguru import logger as log
import sentencepiece as spm
-from text_recognizer.data.iam_preprocessor import load_metadata
-
def iamdb_pieces(
data_dir: Path, text_file: str, num_pieces: int, output_prefix: str
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
deleted file mode 100644
index d1c64dd..0000000
--- a/text_recognizer/data/mappings.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""Mapping to and from word pieces."""
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Dict, List, Optional, Union, Set
-
-import attr
-import torch
-from loguru import logger as log
-from torch import Tensor
-
-from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.data.iam_preprocessor import Preprocessor
-
-
-@attr.s
-class AbstractMapping(ABC):
- input_size: List[int] = attr.ib(init=False)
- mapping: List[str] = attr.ib(init=False)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
- def __len__(self) -> int:
- return len(self.mapping)
-
- @property
- def num_classes(self) -> int:
- return self.__len__()
-
- @abstractmethod
- def get_token(self, *args, **kwargs) -> str:
- ...
-
- @abstractmethod
- def get_index(self, *args, **kwargs) -> Tensor:
- ...
-
- @abstractmethod
- def get_text(self, *args, **kwargs) -> str:
- ...
-
- @abstractmethod
- def get_indices(self, *args, **kwargs) -> Tensor:
- ...
-
-
-@attr.s(auto_attribs=True)
-class EmnistMapping(AbstractMapping):
- extra_symbols: Optional[Set[str]] = attr.ib(default=None)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
- self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- self.extra_symbols
- )
-
- def get_token(self, index: Union[int, Tensor]) -> str:
- if (index := int(index)) in self.mapping:
- return self.mapping[index]
- raise KeyError(f"Index ({index}) not in mapping.")
-
- def get_index(self, token: str) -> Tensor:
- if token in self.inverse_mapping:
- return Tensor(self.inverse_mapping[token])
- raise KeyError(f"Token ({token}) not found in inverse mapping.")
-
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
- if isinstance(indices, Tensor):
- indices = indices.tolist()
- return "".join([self.mapping[index] for index in indices])
-
- def get_indices(self, text: str) -> Tensor:
- return Tensor([self.inverse_mapping[token] for token in text])
-
-
-@attr.s(auto_attribs=True)
-class WordPieceMapping(EmnistMapping):
- data_dir: Optional[Path] = attr.ib(default=None)
- num_features: int = attr.ib(default=1000)
- tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
- lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
- use_words: bool = attr.ib(default=False)
- prepend_wordsep: bool = attr.ib(default=False)
- special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
- extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
- wordpiece_processor: Preprocessor = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- super().__attrs_post_init__()
- self.data_dir = (
- (
- Path(__file__).resolve().parents[2]
- / "data"
- / "downloaded"
- / "iam"
- / "iamdb"
- )
- if self.data_dir is None
- else Path(self.data_dir)
- )
- log.debug(f"Using data dir: {self.data_dir}")
- if not self.data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
-
- processed_path = (
- Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
- )
-
- tokens_path = processed_path / self.tokens
- lexicon_path = processed_path / self.lexicon
-
- special_tokens = self.special_tokens
- if self.extra_symbols is not None:
- special_tokens = special_tokens | self.extra_symbols
-
- self.wordpiece_processor = Preprocessor(
- data_dir=self.data_dir,
- num_features=self.num_features,
- tokens_path=tokens_path,
- lexicon_path=lexicon_path,
- use_words=self.use_words,
- prepend_wordsep=self.prepend_wordsep,
- special_tokens=special_tokens,
- )
-
- def __len__(self) -> int:
- return len(self.wordpiece_processor.tokens)
-
- def get_token(self, index: Union[int, Tensor]) -> str:
- if (index := int(index)) <= self.wordpiece_processor.num_tokens:
- return self.wordpiece_processor.tokens[index]
- raise KeyError(f"Index ({index}) not in mapping.")
-
- def get_index(self, token: str) -> Tensor:
- if token in self.wordpiece_processor.tokens:
- return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
- raise KeyError(f"Token ({token}) not found in inverse mapping.")
-
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
- if isinstance(indices, Tensor):
- indices = indices.tolist()
- return self.wordpiece_processor.to_text(indices)
-
- def get_indices(self, text: str) -> Tensor:
- return self.wordpiece_processor.to_index(text)
-
- def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
- text = "".join([self.mapping[i] for i in x])
- text = text.lower().replace(" ", "▁")
- return torch.LongTensor(self.wordpiece_processor.to_index(text))
-
- def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
- if isinstance(x, int):
- x = [x]
- if isinstance(x, str):
- return self.get_indices(x)
- return self.get_text(x)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 3b1b929..047496f 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,11 +1,11 @@
"""Transforms for PyTorch datasets."""
from pathlib import Path
-from typing import Optional, Union, Sequence
+from typing import Optional, Union, Set
import torch
from torch import Tensor
-from text_recognizer.data.mappings import WordPieceMapping
+from text_recognizer.data.word_piece_mapping import WordPieceMapping
class WordPiece:
@@ -19,8 +19,8 @@ class WordPiece:
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Optional[Set[str]] = {"\n",},
max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
new file mode 100644
index 0000000..59488c3
--- /dev/null
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -0,0 +1,93 @@
+"""Word piece mapping."""
+from pathlib import Path
+from typing import List, Optional, Union, Set
+
+import torch
+from loguru import logger as log
+from torch import Tensor
+
+from text_recognizer.data.emnist_mapping import EmnistMapping
+from text_recognizer.data.iam_preprocessor import Preprocessor
+
+
+class WordPieceMapping(EmnistMapping):
+ def __init__(
+ self,
+ data_dir: Optional[Path] = None,
+ num_features: int = 1000,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Set[str] = {"\n",},
+ ) -> None:
+ super().__init__(extra_symbols=extra_symbols)
+ self.data_dir = (
+ (
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
+ if data_dir is None
+ else Path(data_dir)
+ )
+ log.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
+
+ processed_path = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ special_tokens = set(special_tokens)
+ if self.extra_symbols is not None:
+ special_tokens = special_tokens | set(extra_symbols)
+
+ self.wordpiece_processor = Preprocessor(
+ data_dir=self.data_dir,
+ num_features=num_features,
+ tokens_path=tokens_path,
+ lexicon_path=lexicon_path,
+ use_words=use_words,
+ prepend_wordsep=prepend_wordsep,
+ special_tokens=special_tokens,
+ )
+
+ def __len__(self) -> int:
+ return len(self.wordpiece_processor.tokens)
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ if (index := int(index)) <= self.wordpiece_processor.num_tokens:
+ return self.wordpiece_processor.tokens[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ if token in self.wordpiece_processor.tokens:
+ return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return self.wordpiece_processor.to_text(indices).replace(" ", "▁")
+
+ def get_indices(self, text: str) -> Tensor:
+ return self.wordpiece_processor.to_index(text)
+
+ def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
+ text = "".join([self.mapping[i] for i in x])
+ text = text.lower().replace(" ", "▁")
+ return torch.LongTensor(self.wordpiece_processor.to_index(text))
+
+ def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+ if isinstance(x, int):
+ x = [x]
+ if isinstance(x, str):
+ return self.get_indices(x)
+ return self.get_text(x)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8ce5c37..57c5964 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,6 +11,8 @@ from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.data.base_mapping import AbstractMapping
+
@attr.s(eq=False)
class BaseLitModel(LightningModule):
@@ -20,12 +22,12 @@ class BaseLitModel(LightningModule):
super().__init__()
network: Type[nn.Module] = attr.ib()
- criterion_config: DictConfig = attr.ib(converter=DictConfig)
- optimizer_config: DictConfig = attr.ib(converter=DictConfig)
- lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ mapping: Type[AbstractMapping] = attr.ib()
+ loss_fn: Type[nn.Module] = attr.ib()
+ optimizer_config: DictConfig = attr.ib()
+ lr_scheduler_config: DictConfig = attr.ib()
interval: str = attr.ib()
monitor: str = attr.ib(default="val/loss")
- loss_fn: Type[nn.Module] = attr.ib(init=False)
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)
@@ -36,12 +38,6 @@ class BaseLitModel(LightningModule):
init=False, default=torchmetrics.Accuracy()
)
- @loss_fn.default
- def configure_criterion(self) -> Type[nn.Module]:
- """Returns a loss functions."""
- log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
- return hydra.utils.instantiate(self.criterion_config)
-
def optimizer_zero_grad(
self,
epoch: int,
@@ -54,7 +50,9 @@ class BaseLitModel(LightningModule):
def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
"""Configures the optimizer."""
log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
- return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())
+ return hydra.utils.instantiate(
+ self.optimizer_config, params=self.network.parameters()
+ )
def _configure_lr_scheduler(
self, optimizer: Type[torch.optim.Optimizer]
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 91e088d..5fb84a7 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -5,7 +5,6 @@ import attr
import torch
from torch import Tensor
-from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -14,14 +13,14 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib(default=None)
+ max_output_len: int = attr.ib(default=451)
start_token: str = attr.ib(default="<s>")
end_token: str = attr.ib(default="<e>")
pad_token: str = attr.ib(default="<p>")
- start_index: Tensor = attr.ib(init=False)
- end_index: Tensor = attr.ib(init=False)
- pad_index: Tensor = attr.ib(init=False)
+ start_index: int = attr.ib(init=False)
+ end_index: int = attr.ib(init=False)
+ pad_index: int = attr.ib(init=False)
ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
@@ -29,9 +28,9 @@ class TransformerLitModel(BaseLitModel):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = self.mapping.get_index(self.start_token)
- self.end_index = self.mapping.get_index(self.end_token)
- self.pad_index = self.mapping.get_index(self.pad_token)
+ self.start_index = int(self.mapping.get_index(self.start_token))
+ self.end_index = int(self.mapping.get_index(self.end_token))
+ self.pad_index = int(self.mapping.get_index(self.pad_token))
self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)
@@ -93,23 +92,24 @@ class TransformerLitModel(BaseLitModel):
output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
output[:, 0] = self.start_index
- for i in range(1, self.max_output_len):
- context = output[:, :i] # (bsz, i)
- logits = self.network.decode(z, context) # (i, bsz, c)
- tokens = torch.argmax(logits, dim=-1) # (i, bsz)
- output[:, i : i + 1] = tokens[-1:]
+ for Sy in range(1, self.max_output_len):
+ context = output[:, :Sy] # (B, Sy)
+ logits = self.network.decode(z, context) # (B, Sy, C)
+ tokens = torch.argmax(logits, dim=-1) # (B, Sy)
+ output[:, Sy : Sy + 1] = tokens[:, -1:]
# Early stopping of prediction loop if token is end or padding token.
if (
- output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+ (output[:, Sy - 1] == self.end_index)
+ | (output[:, Sy - 1] == self.pad_index)
).all():
break
# Set all tokens after end token to pad token.
- for i in range(1, self.max_output_len):
- idx = (
- output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+ for Sy in range(1, self.max_output_len):
+ idx = (output[:, Sy - 1] == self.end_index) | (
+ output[:, Sy - 1] == self.pad_index
)
- output[idx, i] = self.pad_index
+ output[idx, Sy] = self.pad_index
return output
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 09cc654..f3ba49d 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -2,7 +2,6 @@
import math
from typing import Tuple
-import attr
from torch import nn, Tensor
from text_recognizer.networks.encoders.efficientnet import EfficientNet
@@ -13,32 +12,28 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s(eq=False)
class ConvTransformer(nn.Module):
"""Convolutional encoder and transformer decoder network."""
- def __attrs_pre_init__(self) -> None:
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ dropout_rate: float,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: EfficientNet,
+ decoder: Decoder,
+ ) -> None:
super().__init__()
+ self.input_dims = input_dims
+ self.hidden_dim = hidden_dim
+ self.dropout_rate = dropout_rate
+ self.num_classes = num_classes
+ self.pad_index = pad_index
+ self.encoder = encoder
+ self.decoder = decoder
- # Parameters and placeholders,
- input_dims: Tuple[int, int, int] = attr.ib()
- hidden_dim: int = attr.ib()
- dropout_rate: float = attr.ib()
- max_output_len: int = attr.ib()
- num_classes: int = attr.ib()
- pad_index: Tensor = attr.ib()
-
- # Modules.
- encoder: EfficientNet = attr.ib()
- decoder: Decoder = attr.ib()
-
- latent_encoder: nn.Sequential = attr.ib(init=False)
- token_embedding: nn.Embedding = attr.ib(init=False)
- token_pos_encoder: PositionalEncoding = attr.ib(init=False)
- head: nn.Linear = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -126,7 +121,8 @@ class ConvTransformer(nn.Module):
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
out = self.decoder(x=context, context=z, mask=context_mask)
- logits = self.head(out)
+ logits = self.head(out) # [B, Sy, T]
+ logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits
def forward(self, x: Tensor, context: Tensor) -> Tensor:
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index e85df87..7bfd9ba 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -11,9 +11,7 @@ from text_recognizer.networks.encoders.efficientnet.utils import stochastic_dept
def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
"""Converts int to tuple."""
- return (
- (stride,) * 2 if isinstance(stride, int) else stride
- )
+ return (stride,) * 2 if isinstance(stride, int) else stride
@attr.s(eq=False)
@@ -41,10 +39,7 @@ class MBConvBlock(nn.Module):
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
- return (
- (self.kernel_size - 1) // 2 - 1,
- (self.kernel_size - 1) // 2,
- ) * 2
+ return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
return ((self.kernel_size - 1) // 2,) * 4
def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index ce443e5..70a0ac7 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,5 +1,4 @@
"""Transformer attention layer."""
-from functools import partial
from typing import Any, Dict, Optional, Tuple
import attr
@@ -27,25 +26,17 @@ class AttentionLayers(nn.Module):
norm_fn: str = attr.ib()
ff_fn: str = attr.ib()
ff_kwargs: Dict = attr.ib()
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib()
causal: bool = attr.ib(default=False)
cross_attend: bool = attr.ib(default=False)
pre_norm: bool = attr.ib(default=True)
- rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
- attn: partial = attr.ib(init=False)
- norm: partial = attr.ib(init=False)
- ff: partial = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self.layer_types = self._get_layer_types() * self.depth
- attn = load_partial_fn(
- self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
- )
- norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
- ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
- self.layers = self._build_network(attn, norm, ff)
+ self.layers = self._build_network()
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
@@ -53,10 +44,13 @@ class AttentionLayers(nn.Module):
return "a", "c", "f"
return "a", "f"
- def _build_network(
- self, attn: partial, norm: partial, ff: partial,
- ) -> nn.ModuleList:
+ def _build_network(self) -> nn.ModuleList:
"""Configures transformer network."""
+ attn = load_partial_fn(
+ self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+ )
+ norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
+ ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
layers = nn.ModuleList([])
for layer_type in self.layer_types:
if layer_type == "a":
@@ -106,6 +100,7 @@ class Encoder(AttentionLayers):
causal: bool = attr.ib(default=False, init=False)
-@attr.s(auto_attribs=True, eq=False)
class Decoder(AttentionLayers):
- causal: bool = attr.ib(default=True, init=False)
+ def __init__(self, **kwargs: Any) -> None:
+ assert "causal" not in kwargs, "Cannot set causality on decoder"
+ super().__init__(causal=True, **kwargs)
diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/__init__.py
index e69de29..e69de29 100644
--- a/training/conf/callbacks/wandb/image_reconstructions.yaml
+++ b/training/__init__.py
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 6379cc0..906531f 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -1,11 +1,10 @@
"""Weights and Biases callbacks."""
from pathlib import Path
-from typing import List
-import attr
import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+from pytorch_lightning.utilities import rank_zero_only
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -22,31 +21,27 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:
raise Exception("Weight and Biases logger not found for some reason...")
-@attr.s
class WatchModel(Callback):
"""Make W&B watch the model at the beginning of the run."""
- log: str = attr.ib(default="gradients")
- log_freq: int = attr.ib(default=100)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, log: str = "gradients", log_freq: int = 100) -> None:
+ self.log = log
+ self.log_freq = log_freq
+ @rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Watches model weights with wandb."""
logger = get_wandb_logger(trainer)
logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
-@attr.s
class UploadCodeAsArtifact(Callback):
"""Upload all *.py files to W&B as an artifact, at the beginning of the run."""
- project_dir: Path = attr.ib(converter=Path)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, project_dir: str) -> None:
+ self.project_dir = Path(project_dir)
+ @rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads project code as an artifact."""
logger = get_wandb_logger(trainer)
@@ -58,16 +53,16 @@ class UploadCodeAsArtifact(Callback):
experiment.use_artifact(artifact)
-@attr.s
-class UploadCheckpointAsArtifact(Callback):
+class UploadCheckpointsAsArtifact(Callback):
"""Upload checkpoint to wandb as an artifact, at the end of a run."""
- ckpt_dir: Path = attr.ib(converter=Path)
- upload_best_only: bool = attr.ib()
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(
+ self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
+ ) -> None:
+ self.ckpt_dir = ckpt_dir
+ self.upload_best_only = upload_best_only
+ @rank_zero_only
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads model checkpoint to W&B."""
logger = get_wandb_logger(trainer)
@@ -83,15 +78,12 @@ class UploadCheckpointAsArtifact(Callback):
experiment.use_artifact(ckpts)
-@attr.s
class LogTextPredictions(Callback):
"""Logs a validation batch with image to text transcription."""
- num_samples: int = attr.ib(default=8)
- ready: bool = attr.ib(default=True)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, num_samples: int = 8) -> None:
+ self.num_samples = num_samples
+ self.ready = False
def _log_predictions(
self, stage: str, trainer: Trainer, pl_module: LightningModule
@@ -111,20 +103,20 @@ class LogTextPredictions(Callback):
logits = pl_module(imgs)
mapping = pl_module.mapping
+ columns = ["id", "image", "prediction", "truth"]
+ data = [
+ [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+ for id, (img, pred, label) in enumerate(
+ zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
+ )
+ )
+ ]
+
experiment.log(
- {
- f"OCR/{experiment.name}/{stage}": [
- wandb.Image(
- img,
- caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
- )
- for img, pred, label in zip(
- imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
- )
- ]
- }
+ {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
)
def on_sanity_check_start(
@@ -143,20 +135,17 @@ class LogTextPredictions(Callback):
"""Logs predictions on validation epoch end."""
self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
-@attr.s
class LogReconstuctedImages(Callback):
"""Log reconstructions of images."""
- num_samples: int = attr.ib(default=8)
- ready: bool = attr.ib(default=True)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, num_samples: int = 8) -> None:
+ self.num_samples = num_samples
+ self.ready = False
def _log_reconstruction(
self, stage: str, trainer: Trainer, pl_module: LightningModule
@@ -202,6 +191,6 @@ class LogReconstuctedImages(Callback):
"""Logs predictions on validation epoch end."""
self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index db34cb1..b4101d8 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -6,4 +6,4 @@ model_checkpoint:
mode: min # can be "max" or "min"
verbose: false
dirpath: checkpoints/
- filename: {epoch:02d}
+ filename: "{epoch:02d}"
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml
index a4a16ff..a4a16ff 100644
--- a/training/conf/callbacks/wandb/checkpoints.yaml
+++ b/training/conf/callbacks/wandb_checkpoints.yaml
diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb_code.yaml
index 35f6ea3..35f6ea3 100644
--- a/training/conf/callbacks/wandb/code.yaml
+++ b/training/conf/callbacks/wandb_code.yaml
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/conf/callbacks/wandb_image_reconstructions.yaml
diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml
index efa3dda..9c9a6da 100644
--- a/training/conf/callbacks/wandb_ocr.yaml
+++ b/training/conf/callbacks/wandb_ocr.yaml
@@ -1,6 +1,6 @@
defaults:
- default
- - wandb/watch
- - wandb/code
- - wandb/checkpoints
- - wandb/ocr_predictions
+ - wandb_watch
+ - wandb_code
+ - wandb_checkpoints
+ - wandb_ocr_predictions
diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb_ocr_predictions.yaml
index 573fa96..573fa96 100644
--- a/training/conf/callbacks/wandb/ocr_predictions.yaml
+++ b/training/conf/callbacks/wandb_ocr_predictions.yaml
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb_watch.yaml
index 511608c..511608c 100644
--- a/training/conf/callbacks/wandb/watch.yaml
+++ b/training/conf/callbacks/wandb_watch.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 93215ed..782bcbb 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,8 +1,9 @@
defaults:
- callbacks: wandb_ocr
- criterion: label_smoothing
- - dataset: iam_extended_paragraphs
+ - datamodule: iam_extended_paragraphs
- hydra: default
+ - logger: wandb
- lr_scheduler: one_cycle
- mapping: word_piece
- model: lit_transformer
@@ -15,3 +16,21 @@ tune: false
train: true
test: true
logging: INFO
+
+# path to original working directory
+# hydra hijacks working directory by changing it to the current log directory,
+# so it's useful to have this path as a special variable
+# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
+work_dir: ${hydra:runtime.cwd}
+
+# use `python run.py debug=true` for easy debugging!
+# this will run 1 train, val and test loop with only 1 batch
+# equivalent to running `python run.py trainer.fast_dev_run=true`
+# (this is placed here just for easier access from command line)
+debug: False
+
+# pretty print config at the start of the run using Rich library
+print_config: True
+
+# disable python warnings if they annoy you
+ignore_warnings: True
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
index 13daba8..684b5bb 100644
--- a/training/conf/criterion/label_smoothing.yaml
+++ b/training/conf/criterion/label_smoothing.yaml
@@ -1,4 +1,3 @@
-_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
-label_smoothing: 0.1
-vocab_size: 1006
+_target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
+smoothing: 0.1
ignore_index: 1002
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index 3070b56..2d1a03e 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -1,5 +1,6 @@
_target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-batch_size: 32
+batch_size: 4
num_workers: 12
train_fraction: 0.8
augment: true
+pin_memory: false
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index 5afdf81..eecee8a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,8 +1,8 @@
_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 1.0e-3
total_steps: null
-epochs: null
-steps_per_epoch: null
+epochs: 512
+steps_per_epoch: 4992
pct_start: 0.3
anneal_strategy: cos
cycle_momentum: true
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index 3792523..48384f5 100644
--- a/training/conf/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.data.mappings.WordPieceMapping
+_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
num_features: 1000
tokens: iamdb_1kwp_tokens_1000.txt
lexicon: iamdb_1kwp_lex_1000.txt
@@ -6,4 +6,4 @@ data_dir: null
use_words: false
prepend_wordsep: false
special_tokens: [ <s>, <e>, <p> ]
-extra_symbols: [ \n ]
+extra_symbols: [ "\n" ]
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 6ffde4e..c190151 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,7 +1,7 @@
_target_: text_recognizer.models.transformer.TransformerLitModel
interval: step
monitor: val/loss
-ignore_tokens: [ <s>, <e>, <p> ]
+max_output_len: 451
start_token: <s>
end_token: <e>
pad_token: <p>
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index a97157d..f76e892 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -6,6 +6,5 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
hidden_dim: 96
dropout_rate: 0.2
-max_output_len: 451
num_classes: 1006
pad_index: 1002
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 90b9d8a..eb80f64 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -18,3 +18,4 @@ ff_kwargs:
dropout_rate: 0.2
cross_attend: true
pre_norm: true
+rotary_emb: null
diff --git a/training/run.py b/training/run.py
index 30479c6..13a6a82 100644
--- a/training/run.py
+++ b/training/run.py
@@ -12,35 +12,40 @@ from pytorch_lightning import (
Trainer,
)
from pytorch_lightning.loggers import LightningLoggerBase
-from text_recognizer.data.mappings import AbstractMapping
from torch import nn
+from text_recognizer.data.base_mapping import AbstractMapping
import utils
def run(config: DictConfig) -> Optional[float]:
"""Runs experiment."""
- utils.configure_logging(config.logging)
+ utils.configure_logging(config)
log.info("Starting experiment...")
if config.get("seed"):
- seed_everything(config.seed)
+ seed_everything(config.seed, workers=True)
log.info(f"Instantiating mapping <{config.mapping._target_}>")
mapping: AbstractMapping = hydra.utils.instantiate(config.mapping)
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
- datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)
+ datamodule: LightningDataModule = hydra.utils.instantiate(
+ config.datamodule, mapping=mapping
+ )
log.info(f"Instantiating network <{config.network._target_}>")
network: nn.Module = hydra.utils.instantiate(config.network)
+ log.info(f"Instantiating criterion <{config.criterion._target_}>")
+ loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
+
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
- **config.model,
+ config.model,
mapping=mapping,
network=network,
- criterion_config=config.criterion,
+ loss_fn=loss_fn,
optimizer_config=config.optimizer,
lr_scheduler_config=config.lr_scheduler,
_recursive_=False,
@@ -77,4 +82,4 @@ def run(config: DictConfig) -> Optional[float]:
trainer.test(model, datamodule=datamodule)
log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
- utils.finish(trainer)
+ utils.finish(logger)
diff --git a/training/utils.py b/training/utils.py
index ef74f61..d23396e 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -17,6 +17,10 @@ from tqdm import tqdm
import wandb
+def print_config(config: DictConfig) -> None:
+ print(OmegaConf.to_yaml(config))
+
+
@rank_zero_only
def configure_logging(config: DictConfig) -> None:
"""Configure the loguru logger for output to terminal and disk."""
@@ -30,7 +34,7 @@ def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
callbacks = []
if config.get("callbacks"):
for callback_config in config.callbacks.values():
- if config.get("_target_"):
+ if callback_config.get("_target_"):
log.info(f"Instantiating callback <{callback_config._target_}>")
callbacks.append(hydra.utils.instantiate(callback_config))
return callbacks
@@ -41,8 +45,8 @@ def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]:
logger = []
if config.get("logger"):
for logger_config in config.logger.values():
- if config.get("_target_"):
- log.info(f"Instantiating callback <{logger_config._target_}>")
+ if logger_config.get("_target_"):
+ log.info(f"Instantiating logger <{logger_config._target_}>")
logger.append(hydra.utils.instantiate(logger_config))
return logger
@@ -67,17 +71,8 @@ def extras(config: DictConfig) -> None:
# Debuggers do not like GPUs and multiprocessing.
if config.trainer.get("gpus"):
config.trainer.gpus = 0
- if config.datamodule.get("pin_memory"):
- config.datamodule.pin_memory = False
- if config.datamodule.get("num_workers"):
- config.datamodule.num_workers = 0
-
- # Force multi-gpu friendly config.
- accelerator = config.trainer.get("accelerator")
- if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]:
- log.info(
- f"Forcing ddp friendly configuration! <config.trainer.accelerator={accelerator}>"
- )
+ if config.trainer.get("precision"):
+ config.trainer.precision = 32
if config.datamodule.get("pin_memory"):
config.datamodule.pin_memory = False
if config.datamodule.get("num_workers"):