Diffstat (limited to 'notebooks/05c-test-model-end-to-end.ipynb')
-rw-r--r-- | notebooks/05c-test-model-end-to-end.ipynb | 448
1 file changed, 426 insertions, 22 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index b652bdd..e3e92e2 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -2,10 +2,19 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "id": "1e40a88b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -25,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0", "metadata": {}, "outputs": [], @@ -45,32 +54,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "mapping:\n", - " _target_: text_recognizer.data.mappings.WordPieceMapping\n", - " num_features: 1000\n", - " tokens: iamdb_1kwp_tokens_1000.txt\n", - " lexicon: iamdb_1kwp_lex_1000.txt\n", - " data_dir: null\n", - " use_words: false\n", - " prepend_wordsep: false\n", - " special_tokens:\n", - " - <s>\n", - " - <e>\n", - " - <p>\n", - " extra_symbols:\n", - " - \\n\n", "_target_: text_recognizer.models.transformer.TransformerLitModel\n", "interval: step\n", "monitor: val/loss\n", - "ignore_tokens:\n", - "- <s>\n", - "- <e>\n", - "- <p>\n", "start_token: <s>\n", "end_token: <e>\n", "pad_token: <p>\n", "\n", - "{'mapping': {'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\\\n']}, '_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'ignore_tokens': ['<s>', '<e>', '<p>'], 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n" + "{'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n" ] } ], @@ -85,6 +76,20 @@ { "cell_type": "code", "execution_count": null, + "id": "5e6b49ce-7685-4491-bd0a-51487f06a237", + "metadata": {}, + "outputs": [], + "source": [ + "# context initialization\n", + "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"word_piece\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "9c797159-845e-42c6-bd65-1c976ad627cd", "metadata": {}, "outputs": [], @@ -98,6 +103,405 @@ }, { "cell_type": "code", + "execution_count": 6, + "id": "764c8736-7d68-4261-a57d-face10ebbf42", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "callbacks:\n", + " model_checkpoint:\n", + " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n", + " monitor: val/loss\n", + " save_top_k: 1\n", + " save_last: true\n", + " mode: min\n", + " verbose: false\n", + " dirpath: checkpoints/\n", + " filename:\n", + " epoch:02d: null\n", + " learning_rate_monitor:\n", + " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n", + " logging_interval: step\n", + " log_momentum: false\n", + " watch_model:\n", + " _target_: callbacks.wandb_callbacks.WatchModel\n", + " log: all\n", + " log_freq: 100\n", + " 
upload_code_as_artifact:\n", + " _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact\n", + " project_dir: ${work_dir}/text_recognizer\n", + " upload_ckpts_as_artifact:\n", + " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", + " ckpt_dir: checkpoints/\n", + " upload_best_only: true\n", + " log_text_predictions:\n", + " _target_: callbacks.wandb_callbacks.LogTextPredictions\n", + " num_samples: 8\n", + "criterion:\n", + " _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss\n", + " smoothing: 0.1\n", + " ignore_index: 1002\n", + "datamodule:\n", + " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n", + " batch_size: 8\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " augment: true\n", + " pin_memory: false\n", + "logger:\n", + " wandb:\n", + " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n", + " project: text-recognizer\n", + " name: null\n", + " save_dir: .\n", + " offline: false\n", + " id: null\n", + " log_model: false\n", + " prefix: ''\n", + " job_type: train\n", + " group: ''\n", + " tags: []\n", + "lr_scheduler:\n", + " _target_: torch.optim.lr_scheduler.OneCycleLR\n", + " max_lr: 0.001\n", + " total_steps: null\n", + " epochs: 512\n", + " steps_per_epoch: 4992\n", + " pct_start: 0.3\n", + " anneal_strategy: cos\n", + " cycle_momentum: true\n", + " base_momentum: 0.85\n", + " max_momentum: 0.95\n", + " div_factor: 25.0\n", + " final_div_factor: 10000.0\n", + " three_phase: true\n", + " last_epoch: -1\n", + " verbose: false\n", + "mapping:\n", + " _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping\n", + " num_features: 1000\n", + " tokens: iamdb_1kwp_tokens_1000.txt\n", + " lexicon: iamdb_1kwp_lex_1000.txt\n", + " data_dir: null\n", + " use_words: false\n", + " prepend_wordsep: false\n", + " special_tokens:\n", + " - <s>\n", + " - <e>\n", + " - <p>\n", + " extra_symbols:\n", + " - '\n", + "\n", + " '\n", + "model:\n", + " _target_: text_recognizer.models.transformer.TransformerLitModel\n", + " interval: step\n", + " monitor: val/loss\n", + " max_output_len: 451\n", + " start_token: <s>\n", + " end_token: <e>\n", + " pad_token: <p>\n", + "network:\n", + " encoder:\n", + " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n", + " arch: b0\n", + " out_channels: 1280\n", + " stochastic_dropout_rate: 0.2\n", + " bn_momentum: 0.99\n", + " bn_eps: 0.001\n", + " decoder:\n", + " _target_: text_recognizer.networks.transformer.Decoder\n", + " dim: 96\n", + " depth: 2\n", + " num_heads: 8\n", + " attn_fn: text_recognizer.networks.transformer.attention.Attention\n", + " attn_kwargs:\n", + " dim_head: 16\n", + " dropout_rate: 0.2\n", + " norm_fn: torch.nn.LayerNorm\n", + " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n", + " ff_kwargs:\n", + " dim_out: null\n", + " expansion_factor: 4\n", + " glu: true\n", + " dropout_rate: 0.2\n", + " cross_attend: true\n", + " pre_norm: true\n", + " rotary_emb: null\n", + " _target_: text_recognizer.networks.conv_transformer.ConvTransformer\n", + " input_dims:\n", + " - 1\n", + " - 576\n", + " - 640\n", + " hidden_dim: 96\n", + " dropout_rate: 0.2\n", + " num_classes: 1006\n", + " pad_index: 1002\n", + "optimizer:\n", + " _target_: madgrad.MADGRAD\n", + " lr: 0.001\n", + " momentum: 0.9\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + "trainer:\n", + " _target_: pytorch_lightning.Trainer\n", + " stochastic_weight_avg: false\n", + " auto_scale_batch_size: binsearch\n", + " auto_lr_find: false\n", + " gradient_clip_val: 
0\n", + " fast_dev_run: false\n", + " gpus: 1\n", + " precision: 16\n", + " max_epochs: 512\n", + " terminate_on_nan: true\n", + " weights_summary: top\n", + " limit_train_batches: 1.0\n", + " limit_val_batches: 1.0\n", + " limit_test_batches: 1.0\n", + " resume_from_checkpoint: null\n", + "seed: 4711\n", + "tune: false\n", + "train: true\n", + "test: true\n", + "logging: INFO\n", + "debug: false\n", + "\n", + "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}, 'criterion': {'_target_': 'text_recognizer.criterions.label_smoothing.LabelSmoothingLoss', 'smoothing': 0.1, 'ignore_index': 1002}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 8, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 512, 'steps_per_epoch': 4992, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'max_output_len': 451, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}, 'network': {'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 96, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'dim_head': 16, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'cross_attend': True, 'pre_norm': True, 'rotary_emb': None}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 
576, 640], 'hidden_dim': 96, 'dropout_rate': 0.2, 'num_classes': 1006, 'pad_index': 1002}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 512, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'debug': False}\n" + ] + } + ], + "source": [ + "# context initialization\n", + "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"config\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9382f0ab-8760-4d59-b0b5-b8b65dd1ea31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg.get(\"callbacks\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "216d5680-66bf-4190-9401-1a59dbbc43af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pytorch_lightning.callbacks.ModelCheckpoint\n", + "pytorch_lightning.callbacks.LearningRateMonitor\n", + "callbacks.wandb_callbacks.WatchModel\n", + "callbacks.wandb_callbacks.UploadCodeAsArtifact\n", + "callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", + "callbacks.wandb_callbacks.LogTextPredictions\n" + ] + } + ], + "source": [ + "for l in cfg.callbacks.values():\n", + " print(l.get(\"_target_\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c1a9aa6b-6405-4ffe-b065-02340762476a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-08-03 15:27:02.069 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" + ] + } + ], + "source": [ + "mapping = instantiate(cfg.mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86", + "metadata": {}, + "outputs": [], + "source": [ + "network = instantiate(cfg.network)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a23893a9-a0da-4327-a617-dc0c2011e5e8", + "metadata": {}, + "outputs": [], + "source": [ + 
"OmegaConf.set_struct(cfg, False)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a6fae1fa-492d-4648-80fd-1c0dac659b02", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "datamodule = instantiate(cfg.datamodule, mapping=mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "514053ef-fcac-4f3c-a7c8-72c6927d6798", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-08-03 15:28:22.541 | INFO | text_recognizer.data.iam_paragraphs:setup:95 - Loading IAM paragraph regions and lines for None...\n", + "2021-08-03 15:28:45.280 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:68 - IAM Synthetic dataset steup for stage None...\n" + ] + } + ], + "source": [ + "datamodule.prepare_data()\n", + "datamodule.setup()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4bad950b-a197-4c60-ad89-903124659a98", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4992" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(datamodule.train_dataloader())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7db05cbd-48b3-43fa-a99a-353126311879", + "metadata": {}, + "outputs": [], + "source": [ + "mapping" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f6e01c15-9a1b-4036-87ae-78716c592264", + "metadata": {}, + "outputs": [], + "source": [ + "config = cfg" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4dc475fc-31f4-487e-88c8-b0f445131f5b", + "metadata": {}, + "outputs": [], + "source": [ + "loss_fn = instantiate(cfg.criterion)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c5c8ed64-d98c-47b5-baf2-1ba57a6c882f", + "metadata": {}, + "outputs": [], + "source": [ + "import hydra" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "b5ff5b24-f804-402b-a8ab-f366443025ca", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + " model = hydra.utils.instantiate(\n", + " config.model,\n", + " mapping=mapping,\n", + " network=network,\n", + " loss_fn=loss_fn,\n", + " optimizer_config=config.optimizer,\n", + " lr_scheduler_config=config.lr_scheduler,\n", + " _recursive_=False,\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "99f8a39f-8b10-4f7d-8bff-52794fd48717", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<bound method WordPieceMapping.get_index of <text_recognizer.data.word_piece_mapping.WordPieceMapping object at 0x7fae3b489610>>" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mapping.get_index" + ] + }, + { + "cell_type": "code", "execution_count": null, "id": "af2c8cfa-0b45-4681-b671-0f97ace62516", "metadata": {}, |