author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-24 00:55:29 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-24 00:55:29 +0200
commit     1062298ac5c03ea3d80cc5d9482ee942832665b6 (patch)
tree       ec147297a956d6e2b2dfbe981aa8a76ea56d61d9 /notebooks/03-look-at-iam-paragraphs.ipynb
parent     ca1925433861f6b1037bcd81112d56717d9f153b (diff)
Update notebooks: regenerate 03-look-at-iam-paragraphs.ipynb against the restructured Hydra config (callbacks grouped under lightning/wandb, datamodule transforms loaded from YAML, retuned trainer and optimizer hyperparameters) and replace the Barlow Twins projector cells with transform-loading examples.
Diffstat (limited to 'notebooks/03-look-at-iam-paragraphs.ipynb')
-rw-r--r--  notebooks/03-look-at-iam-paragraphs.ipynb  272
1 file changed, 142 insertions(+), 130 deletions(-)
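
The updated first cell composes the experiment config with Hydra's compose API. Below is a minimal runnable sketch of that pattern, assuming Hydra >= 1.1 and this repo's training/conf directory; the initialize/compose/print lines mirror the cell in the diff that follows.

    # Compose the experiment config the same way the updated notebook cell does.
    # Assumes Hydra >= 1.1; config_path is relative to the notebook's location.
    from hydra import compose, initialize
    from omegaconf import OmegaConf

    with initialize(config_path="../training/conf/"):
        # The +experiment override selects the cnn_transformer_paragraphs experiment.
        cfg = compose(config_name="config", overrides=["+experiment=cnn_transformer_paragraphs"])
        print(OmegaConf.to_yaml(cfg))
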
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index 0f805f6..f57d491 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"id": "6ce2519f",
"metadata": {},
"outputs": [],
@@ -61,7 +61,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 35,
"id": "e9386367-2b49-4633-9936-57081132e59e",
"metadata": {},
"outputs": [
@@ -69,34 +69,51 @@
"name": "stdout",
"output_type": "stream",
"text": [
+ "seed: 4711\n",
+ "tune: false\n",
+ "train: true\n",
+ "test: true\n",
+ "logging: INFO\n",
+ "work_dir: ${hydra:runtime.cwd}\n",
+ "debug: false\n",
+ "print_config: false\n",
+ "ignore_warnings: true\n",
+ "summary:\n",
+ "- - 1\n",
+ " - 1\n",
+ " - 576\n",
+ " - 640\n",
+ "- - 1\n",
+ " - 682\n",
"callbacks:\n",
- " model_checkpoint:\n",
- " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n",
- " monitor: val/loss\n",
- " save_top_k: 1\n",
- " save_last: true\n",
- " mode: min\n",
- " verbose: false\n",
- " dirpath: checkpoints/\n",
- " filename: '{epoch:02d}'\n",
- " learning_rate_monitor:\n",
- " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n",
- " logging_interval: step\n",
- " log_momentum: false\n",
- " watch_model:\n",
- " _target_: callbacks.wandb_callbacks.WatchModel\n",
- " log: all\n",
- " log_freq: 100\n",
- " upload_code_as_artifact:\n",
- " _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact\n",
- " upload_ckpts_as_artifact:\n",
- " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
- " ckpt_dir: checkpoints/\n",
- " upload_best_only: true\n",
- " log_text_predictions:\n",
- " _target_: callbacks.wandb_callbacks.LogTextPredictions\n",
- " num_samples: 8\n",
- " log_train: false\n",
+ " lightning:\n",
+ " model_checkpoint:\n",
+ " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n",
+ " monitor: val/loss\n",
+ " save_top_k: 1\n",
+ " save_last: true\n",
+ " mode: min\n",
+ " verbose: false\n",
+ " dirpath: checkpoints/\n",
+ " filename: '{epoch:02d}'\n",
+ " learning_rate_monitor:\n",
+ " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n",
+ " logging_interval: step\n",
+ " log_momentum: false\n",
+ " wandb:\n",
+ " watch_model:\n",
+ " _target_: callbacks.wandb_callbacks.WatchModel\n",
+ " log: all\n",
+ " log_freq: 100\n",
+ " upload_code_as_artifact:\n",
+ " _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact\n",
+ " upload_ckpts_as_artifact:\n",
+ " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
+ " ckpt_dir: checkpoints/\n",
+ " upload_best_only: true\n",
+ " log_text_predictions:\n",
+ " _target_: callbacks.wandb_callbacks.LogTextPredictions\n",
+ " num_samples: 8\n",
" stochastic_weight_averaging:\n",
" _target_: pytorch_lightning.callbacks.StochasticWeightAveraging\n",
" swa_epoch_start: 0.75\n",
@@ -104,6 +121,20 @@
" annealing_epochs: 10\n",
" annealing_strategy: cos\n",
" device: null\n",
+ "datamodule:\n",
+ " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n",
+ " batch_size: 4\n",
+ " num_workers: 12\n",
+ " train_fraction: 0.8\n",
+ " pin_memory: true\n",
+ " transform: transform/paragraphs.yaml\n",
+ " test_transform: test_transform/paragraphs.yaml\n",
+ " mapping:\n",
+ " _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping\n",
+ " extra_symbols:\n",
+ " - '\n",
+ "\n",
+ " '\n",
"logger:\n",
" wandb:\n",
" _target_: pytorch_lightning.loggers.wandb.WandbLogger\n",
@@ -122,36 +153,20 @@
" stochastic_weight_avg: true\n",
" auto_scale_batch_size: binsearch\n",
" auto_lr_find: false\n",
- " gradient_clip_val: 0.5\n",
+ " gradient_clip_val: 0.0\n",
" fast_dev_run: false\n",
" gpus: 1\n",
" precision: 16\n",
- " max_epochs: 512\n",
+ " max_epochs: 1000\n",
" terminate_on_nan: true\n",
" weights_summary: null\n",
" limit_train_batches: 1.0\n",
" limit_val_batches: 1.0\n",
" limit_test_batches: 1.0\n",
" resume_from_checkpoint: null\n",
- " accumulate_grad_batches: 32\n",
+ " accumulate_grad_batches: 16\n",
" overfit_batches: 0\n",
- "seed: 4711\n",
- "tune: false\n",
- "train: true\n",
- "test: true\n",
- "logging: INFO\n",
- "work_dir: ${hydra:runtime.cwd}\n",
- "debug: false\n",
- "print_config: false\n",
- "ignore_warnings: true\n",
- "summary:\n",
- "- - 1\n",
- " - 1\n",
- " - 576\n",
- " - 640\n",
- "- - 1\n",
- " - 682\n",
- "epochs: 512\n",
+ "epochs: 1000\n",
"ignore_index: 3\n",
"num_classes: 58\n",
"max_output_len: 682\n",
@@ -159,17 +174,18 @@
" _target_: torch.nn.CrossEntropyLoss\n",
" ignore_index: 3\n",
"mapping:\n",
- " _target_: text_recognizer.data.emnist_mapping.EmnistMapping\n",
- " extra_symbols:\n",
- " - '\n",
+ " mapping:\n",
+ " _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping\n",
+ " extra_symbols:\n",
+ " - '\n",
"\n",
- " '\n",
+ " '\n",
"optimizers:\n",
" madgrad:\n",
" _target_: madgrad.MADGRAD\n",
" lr: 0.0003\n",
" momentum: 0.9\n",
- " weight_decay: 0\n",
+ " weight_decay: 5.0e-06\n",
" eps: 1.0e-06\n",
" parameters: network\n",
"lr_schedulers:\n",
@@ -177,9 +193,9 @@
" _target_: torch.optim.lr_scheduler.OneCycleLR\n",
" max_lr: 0.0003\n",
" total_steps: null\n",
- " epochs: 512\n",
- " steps_per_epoch: 90\n",
- " pct_start: 0.1\n",
+ " epochs: 1000\n",
+ " steps_per_epoch: 632\n",
+ " pct_start: 0.03\n",
" anneal_strategy: cos\n",
" cycle_momentum: true\n",
" base_momentum: 0.85\n",
@@ -191,21 +207,12 @@
" verbose: false\n",
" interval: step\n",
" monitor: val/loss\n",
- "datamodule:\n",
- " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n",
- " batch_size: 4\n",
- " num_workers: 12\n",
- " train_fraction: 0.8\n",
- " augment: true\n",
- " pin_memory: false\n",
- " word_pieces: false\n",
- " resize: null\n",
"network:\n",
" _target_: text_recognizer.networks.conv_transformer.ConvTransformer\n",
" input_dims:\n",
" - 1\n",
- " - 56\n",
- " - 1024\n",
+ " - 576\n",
+ " - 640\n",
" hidden_dim: 128\n",
" encoder_dim: 1280\n",
" dropout_rate: 0.2\n",
@@ -250,6 +257,12 @@
" dropout_rate: 0.2\n",
" max_len: 682\n",
"model:\n",
+ " mapping:\n",
+ " _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping\n",
+ " extra_symbols:\n",
+ " - '\n",
+ "\n",
+ " '\n",
" _target_: text_recognizer.models.transformer.TransformerLitModel\n",
" max_output_len: 682\n",
" start_token: <s>\n",
@@ -257,26 +270,18 @@
" pad_token: <p>\n",
"\n"
]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'config': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information\n",
- " warnings.warn(msg, UserWarning)\n"
- ]
}
],
"source": [
"# context initialization\n",
- "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n",
+ "with initialize(config_path=\"../training/conf/\"):\n",
" cfg = compose(config_name=\"config\", overrides=[\"+experiment=cnn_transformer_paragraphs\"])\n",
" print(OmegaConf.to_yaml(cfg))"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 36,
"id": "1c4624d1-6de5-41ab-9208-0988fcdba76d",
"metadata": {},
"outputs": [
@@ -284,8 +289,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2021-10-02 23:26:26.439 | INFO | text_recognizer.data.iam_paragraphs:setup:103 - Loading IAM paragraph regions and lines for None...\n",
- "2021-10-02 23:26:49.529 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:68 - IAM Synthetic dataset steup for stage None...\n"
+ "2021-10-11 22:01:56.908 | INFO | text_recognizer.data.iam_paragraphs:setup:92 - Loading IAM paragraph regions and lines for None...\n",
+ "2021-10-11 22:02:16.771 | DEBUG | text_recognizer.data.transforms.load_transform:_load_config:17 - Loading transforms from config: transform/paragraphs.yaml\n",
+ "2021-10-11 22:02:19.953 | DEBUG | text_recognizer.data.transforms.load_transform:_load_config:17 - Loading transforms from config: test_transform/paragraphs.yaml\n",
+ "2021-10-11 22:02:19.957 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:67 - IAM Synthetic dataset steup for stage None...\n",
+ "2021-10-11 22:02:32.207 | DEBUG | text_recognizer.data.transforms.load_transform:_load_config:17 - Loading transforms from config: transform/paragraphs.yaml\n"
]
},
{
@@ -296,8 +304,8 @@
"Num classes: 58\n",
"Dims: (1, 576, 640)\n",
"Output dims: (682, 1)\n",
- "Train/val/test sizes: 19981, 262, 231\n",
- "Train Batch x stats: (torch.Size([4, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0134), tensor(0.0650), tensor(1.))\n",
+ "Train/val/test sizes: 19882, 262, 231\n",
+ "Train Batch x stats: (torch.Size([4, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0316), tensor(0.0974), tensor(1.))\n",
"Train Batch y stats: (torch.Size([4, 682]), torch.int64, tensor(1), tensor(57))\n",
"Test Batch x stats: (torch.Size([4, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0321), tensor(0.0744), tensor(0.8118))\n",
"Test Batch y stats: (torch.Size([4, 682]), torch.int64, tensor(1), tensor(57))\n",
@@ -306,7 +314,7 @@
}
],
"source": [
- "datamodule = instantiate(cfg.datamodule, mapping=cfg.mapping)\n",
+ "datamodule = instantiate(cfg.datamodule)\n",
"datamodule.prepare_data()\n",
"datamodule.setup()\n",
"print(datamodule)"
@@ -466,112 +474,116 @@
},
{
"cell_type": "code",
- "execution_count": 4,
- "id": "8e667fb0-3ff4-45c7-9de3-8f712407d12f",
+ "execution_count": 23,
+ "id": "6f53950a-6858-40b6-ad4a-2857227d9d59",
"metadata": {},
"outputs": [],
"source": [
- "import torch"
+ "a = torch.randn(2, 1, 576, 640), torch.randn(2, 1, 576, 640)"
]
},
{
"cell_type": "code",
- "execution_count": 23,
- "id": "c4b7e5c5-62a9-4415-954d-65d59ddd82c4",
+ "execution_count": 3,
+ "id": "580b9bbc-b213-4ef8-aefb-cfb453d08a44",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.data.transforms.load_transform import load_transforms"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "15744542-0880-45f2-881b-d81e04668305",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([2, 92160])"
+ "Compose(\n",
+ " <text_recognizer.data.transforms.embed_crop.EmbedCrop object at 0x7f830edf5df0>\n",
+ " ColorJitter(brightness=[0.8, 1.6], contrast=None, saturation=None, hue=None)\n",
+ " RandomAffine(degrees=[-1.0, 1.0], shear=[-30.0, 20.0], interpolation=bilinear)\n",
+ " ToTensor()\n",
+ ")"
]
},
- "execution_count": 23,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "torch.randn(2, 256, 18, 20).flatten(start_dim=1).shape"
+ "load_transforms(\"iam_lines.yaml\")"
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "id": "bf820d79-2ce1-4c15-b006-6cb4407ccbba",
+ "execution_count": 16,
+ "id": "5d7eab42-c407-4b88-9492-e9279a38232a",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[autoreload of text_recognizer.networks.barlow_twins.projector failed: Traceback (most recent call last):\n",
- " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 245, in check\n",
- " superreload(m, reload, self.old_objects)\n",
- " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 410, in superreload\n",
- " update_generic(old_obj, new_obj)\n",
- " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n",
- " update(a, b)\n",
- " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 302, in update_class\n",
- " if update_generic(old_obj, new_obj): continue\n",
- " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n",
- " update(a, b)\n",
- " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 266, in update_function\n",
- " setattr(old, name, getattr(new, name))\n",
- "ValueError: __init__() requires a code object with 0 free vars, not 1\n",
- "]\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "from text_recognizer.networks.barlow_twins.projector import Projector"
+ "from torchvision.transforms import ColorJitter"
]
},
{
"cell_type": "code",
- "execution_count": 24,
- "id": "cda38d24-b69a-4033-b07c-2e7558889cf7",
+ "execution_count": 17,
+ "id": "3d02a3fe-1128-416f-80e1-84c9287e613d",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ColorJitter(brightness=[0.5, 1.0], contrast=None, saturation=None, hue=None)"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "net = Projector([92160, 4096, 256])"
+ "ColorJitter(brightness=[0.5, 1.0])"
]
},
{
"cell_type": "code",
- "execution_count": 25,
- "id": "ed706880-8556-418d-85a6-616ac16a3334",
+ "execution_count": 45,
+ "id": "f0ead6d1-3093-4a42-a3b2-b3cdea75fc21",
"metadata": {},
"outputs": [],
"source": [
- "z = torch.randn(2, 256, 18, 20).flatten(start_dim=1)\n"
+ "import torchvision"
]
},
{
"cell_type": "code",
- "execution_count": 26,
- "id": "58e6fddc-81be-44c1-bbe8-05401aa05287",
+ "execution_count": 46,
+ "id": "f4c1606d-a063-465d-bf22-61a1cbc14ab9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([2, 256])"
+ "<InterpolationMode.BILINEAR: 'bilinear'>"
]
},
- "execution_count": 26,
+ "execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "net(z).shape"
+ "getattr(torchvision.transforms.functional.InterpolationMode, \"BILINEAR\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "6f53950a-6858-40b6-ad4a-2857227d9d59",
+ "id": "617568a7-fde1-4f60-80c5-922d764f0c52",
"metadata": {},
"outputs": [],
"source": []