diff options
Diffstat (limited to 'notebooks/03-look-at-iam-paragraphs.ipynb')
-rw-r--r-- | notebooks/03-look-at-iam-paragraphs.ipynb | 272 |
1 files changed, 142 insertions, 130 deletions
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 0f805f6..f57d491 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "6ce2519f", "metadata": {}, "outputs": [], @@ -61,7 +61,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 35, "id": "e9386367-2b49-4633-9936-57081132e59e", "metadata": {}, "outputs": [ @@ -69,34 +69,51 @@ "name": "stdout", "output_type": "stream", "text": [ + "seed: 4711\n", + "tune: false\n", + "train: true\n", + "test: true\n", + "logging: INFO\n", + "work_dir: ${hydra:runtime.cwd}\n", + "debug: false\n", + "print_config: false\n", + "ignore_warnings: true\n", + "summary:\n", + "- - 1\n", + " - 1\n", + " - 576\n", + " - 640\n", + "- - 1\n", + " - 682\n", "callbacks:\n", - " model_checkpoint:\n", - " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n", - " monitor: val/loss\n", - " save_top_k: 1\n", - " save_last: true\n", - " mode: min\n", - " verbose: false\n", - " dirpath: checkpoints/\n", - " filename: '{epoch:02d}'\n", - " learning_rate_monitor:\n", - " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n", - " logging_interval: step\n", - " log_momentum: false\n", - " watch_model:\n", - " _target_: callbacks.wandb_callbacks.WatchModel\n", - " log: all\n", - " log_freq: 100\n", - " upload_code_as_artifact:\n", - " _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact\n", - " upload_ckpts_as_artifact:\n", - " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", - " ckpt_dir: checkpoints/\n", - " upload_best_only: true\n", - " log_text_predictions:\n", - " _target_: callbacks.wandb_callbacks.LogTextPredictions\n", - " num_samples: 8\n", - " log_train: false\n", + " lightning:\n", + " model_checkpoint:\n", + " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n", + " monitor: val/loss\n", + " save_top_k: 1\n", + " save_last: true\n", + " mode: min\n", + " verbose: false\n", + " dirpath: checkpoints/\n", + " filename: '{epoch:02d}'\n", + " learning_rate_monitor:\n", + " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n", + " logging_interval: step\n", + " log_momentum: false\n", + " wandb:\n", + " watch_model:\n", + " _target_: callbacks.wandb_callbacks.WatchModel\n", + " log: all\n", + " log_freq: 100\n", + " upload_code_as_artifact:\n", + " _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact\n", + " upload_ckpts_as_artifact:\n", + " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", + " ckpt_dir: checkpoints/\n", + " upload_best_only: true\n", + " log_text_predictions:\n", + " _target_: callbacks.wandb_callbacks.LogTextPredictions\n", + " num_samples: 8\n", " stochastic_weight_averaging:\n", " _target_: pytorch_lightning.callbacks.StochasticWeightAveraging\n", " swa_epoch_start: 0.75\n", @@ -104,6 +121,20 @@ " annealing_epochs: 10\n", " annealing_strategy: cos\n", " device: null\n", + "datamodule:\n", + " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n", + " batch_size: 4\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " pin_memory: true\n", + " transform: transform/paragraphs.yaml\n", + " test_transform: test_transform/paragraphs.yaml\n", + " mapping:\n", + " _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping\n", + " extra_symbols:\n", + " - '\n", + "\n", + " '\n", "logger:\n", " wandb:\n", " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n", @@ -122,36 +153,20 @@ " stochastic_weight_avg: true\n", " auto_scale_batch_size: binsearch\n", " auto_lr_find: false\n", - " gradient_clip_val: 0.5\n", + " gradient_clip_val: 0.0\n", " fast_dev_run: false\n", " gpus: 1\n", " precision: 16\n", - " max_epochs: 512\n", + " max_epochs: 1000\n", " terminate_on_nan: true\n", " weights_summary: null\n", " limit_train_batches: 1.0\n", " limit_val_batches: 1.0\n", " limit_test_batches: 1.0\n", " resume_from_checkpoint: null\n", - " accumulate_grad_batches: 32\n", + " accumulate_grad_batches: 16\n", " overfit_batches: 0\n", - "seed: 4711\n", - "tune: false\n", - "train: true\n", - "test: true\n", - "logging: INFO\n", - "work_dir: ${hydra:runtime.cwd}\n", - "debug: false\n", - "print_config: false\n", - "ignore_warnings: true\n", - "summary:\n", - "- - 1\n", - " - 1\n", - " - 576\n", - " - 640\n", - "- - 1\n", - " - 682\n", - "epochs: 512\n", + "epochs: 1000\n", "ignore_index: 3\n", "num_classes: 58\n", "max_output_len: 682\n", @@ -159,17 +174,18 @@ " _target_: torch.nn.CrossEntropyLoss\n", " ignore_index: 3\n", "mapping:\n", - " _target_: text_recognizer.data.emnist_mapping.EmnistMapping\n", - " extra_symbols:\n", - " - '\n", + " mapping:\n", + " _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping\n", + " extra_symbols:\n", + " - '\n", "\n", - " '\n", + " '\n", "optimizers:\n", " madgrad:\n", " _target_: madgrad.MADGRAD\n", " lr: 0.0003\n", " momentum: 0.9\n", - " weight_decay: 0\n", + " weight_decay: 5.0e-06\n", " eps: 1.0e-06\n", " parameters: network\n", "lr_schedulers:\n", @@ -177,9 +193,9 @@ " _target_: torch.optim.lr_scheduler.OneCycleLR\n", " max_lr: 0.0003\n", " total_steps: null\n", - " epochs: 512\n", - " steps_per_epoch: 90\n", - " pct_start: 0.1\n", + " epochs: 1000\n", + " steps_per_epoch: 632\n", + " pct_start: 0.03\n", " anneal_strategy: cos\n", " cycle_momentum: true\n", " base_momentum: 0.85\n", @@ -191,21 +207,12 @@ " verbose: false\n", " interval: step\n", " monitor: val/loss\n", - "datamodule:\n", - " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n", - " batch_size: 4\n", - " num_workers: 12\n", - " train_fraction: 0.8\n", - " augment: true\n", - " pin_memory: false\n", - " word_pieces: false\n", - " resize: null\n", "network:\n", " _target_: text_recognizer.networks.conv_transformer.ConvTransformer\n", " input_dims:\n", " - 1\n", - " - 56\n", - " - 1024\n", + " - 576\n", + " - 640\n", " hidden_dim: 128\n", " encoder_dim: 1280\n", " dropout_rate: 0.2\n", @@ -250,6 +257,12 @@ " dropout_rate: 0.2\n", " max_len: 682\n", "model:\n", + " mapping:\n", + " _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping\n", + " extra_symbols:\n", + " - '\n", + "\n", + " '\n", " _target_: text_recognizer.models.transformer.TransformerLitModel\n", " max_output_len: 682\n", " start_token: <s>\n", @@ -257,26 +270,18 @@ " pad_token: <p>\n", "\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'config': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information\n", - " warnings.warn(msg, UserWarning)\n" - ] } ], "source": [ "# context initialization\n", - "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n", + "with initialize(config_path=\"../training/conf/\"):\n", " cfg = compose(config_name=\"config\", overrides=[\"+experiment=cnn_transformer_paragraphs\"])\n", " print(OmegaConf.to_yaml(cfg))" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 36, "id": "1c4624d1-6de5-41ab-9208-0988fcdba76d", "metadata": {}, "outputs": [ @@ -284,8 +289,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-10-02 23:26:26.439 | INFO | text_recognizer.data.iam_paragraphs:setup:103 - Loading IAM paragraph regions and lines for None...\n", - "2021-10-02 23:26:49.529 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:68 - IAM Synthetic dataset steup for stage None...\n" + "2021-10-11 22:01:56.908 | INFO | text_recognizer.data.iam_paragraphs:setup:92 - Loading IAM paragraph regions and lines for None...\n", + "2021-10-11 22:02:16.771 | DEBUG | text_recognizer.data.transforms.load_transform:_load_config:17 - Loading transforms from config: transform/paragraphs.yaml\n", + "2021-10-11 22:02:19.953 | DEBUG | text_recognizer.data.transforms.load_transform:_load_config:17 - Loading transforms from config: test_transform/paragraphs.yaml\n", + "2021-10-11 22:02:19.957 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:67 - IAM Synthetic dataset steup for stage None...\n", + "2021-10-11 22:02:32.207 | DEBUG | text_recognizer.data.transforms.load_transform:_load_config:17 - Loading transforms from config: transform/paragraphs.yaml\n" ] }, { @@ -296,8 +304,8 @@ "Num classes: 58\n", "Dims: (1, 576, 640)\n", "Output dims: (682, 1)\n", - "Train/val/test sizes: 19981, 262, 231\n", - "Train Batch x stats: (torch.Size([4, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0134), tensor(0.0650), tensor(1.))\n", + "Train/val/test sizes: 19882, 262, 231\n", + "Train Batch x stats: (torch.Size([4, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0316), tensor(0.0974), tensor(1.))\n", "Train Batch y stats: (torch.Size([4, 682]), torch.int64, tensor(1), tensor(57))\n", "Test Batch x stats: (torch.Size([4, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0321), tensor(0.0744), tensor(0.8118))\n", "Test Batch y stats: (torch.Size([4, 682]), torch.int64, tensor(1), tensor(57))\n", @@ -306,7 +314,7 @@ } ], "source": [ - "datamodule = instantiate(cfg.datamodule, mapping=cfg.mapping)\n", + "datamodule = instantiate(cfg.datamodule)\n", "datamodule.prepare_data()\n", "datamodule.setup()\n", "print(datamodule)" @@ -466,112 +474,116 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "8e667fb0-3ff4-45c7-9de3-8f712407d12f", + "execution_count": 23, + "id": "6f53950a-6858-40b6-ad4a-2857227d9d59", "metadata": {}, "outputs": [], "source": [ - "import torch" + "a = torch.randn(2, 1, 576, 640), torch.randn(2, 1, 576, 640)" ] }, { "cell_type": "code", - "execution_count": 23, - "id": "c4b7e5c5-62a9-4415-954d-65d59ddd82c4", + "execution_count": 3, + "id": "580b9bbc-b213-4ef8-aefb-cfb453d08a44", + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.data.transforms.load_transform import load_transforms" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "15744542-0880-45f2-881b-d81e04668305", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([2, 92160])" + "Compose(\n", + " <text_recognizer.data.transforms.embed_crop.EmbedCrop object at 0x7f830edf5df0>\n", + " ColorJitter(brightness=[0.8, 1.6], contrast=None, saturation=None, hue=None)\n", + " RandomAffine(degrees=[-1.0, 1.0], shear=[-30.0, 20.0], interpolation=bilinear)\n", + " ToTensor()\n", + ")" ] }, - "execution_count": 23, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "torch.randn(2, 256, 18, 20).flatten(start_dim=1).shape" + "load_transforms(\"iam_lines.yaml\")" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "bf820d79-2ce1-4c15-b006-6cb4407ccbba", + "execution_count": 16, + "id": "5d7eab42-c407-4b88-9492-e9279a38232a", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[autoreload of text_recognizer.networks.barlow_twins.projector failed: Traceback (most recent call last):\n", - " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 245, in check\n", - " superreload(m, reload, self.old_objects)\n", - " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 410, in superreload\n", - " update_generic(old_obj, new_obj)\n", - " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n", - " update(a, b)\n", - " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 302, in update_class\n", - " if update_generic(old_obj, new_obj): continue\n", - " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n", - " update(a, b)\n", - " File \"/home/aktersnurra/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 266, in update_function\n", - " setattr(old, name, getattr(new, name))\n", - "ValueError: __init__() requires a code object with 0 free vars, not 1\n", - "]\n" - ] - } - ], + "outputs": [], "source": [ - "from text_recognizer.networks.barlow_twins.projector import Projector" + "from torchvision.transforms import ColorJitter" ] }, { "cell_type": "code", - "execution_count": 24, - "id": "cda38d24-b69a-4033-b07c-2e7558889cf7", + "execution_count": 17, + "id": "3d02a3fe-1128-416f-80e1-84c9287e613d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ColorJitter(brightness=[0.5, 1.0], contrast=None, saturation=None, hue=None)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "net = Projector([92160, 4096, 256])" + "ColorJitter(brightness=[0.5, 1.0])" ] }, { "cell_type": "code", - "execution_count": 25, - "id": "ed706880-8556-418d-85a6-616ac16a3334", + "execution_count": 45, + "id": "f0ead6d1-3093-4a42-a3b2-b3cdea75fc21", "metadata": {}, "outputs": [], "source": [ - "z = torch.randn(2, 256, 18, 20).flatten(start_dim=1)\n" + "import torchvision" ] }, { "cell_type": "code", - "execution_count": 26, - "id": "58e6fddc-81be-44c1-bbe8-05401aa05287", + "execution_count": 46, + "id": "f4c1606d-a063-465d-bf22-61a1cbc14ab9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([2, 256])" + "<InterpolationMode.BILINEAR: 'bilinear'>" ] }, - "execution_count": 26, + "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "net(z).shape" + "getattr(torchvision.transforms.functional.InterpolationMode, \"BILINEAR\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "6f53950a-6858-40b6-ad4a-2857227d9d59", + "id": "617568a7-fde1-4f60-80c5-922d764f0c52", "metadata": {}, "outputs": [], "source": [] |