summaryrefslogtreecommitdiff
path: root/notebooks
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks')
-rw-r--r--notebooks/05c-test-model-end-to-end.ipynb84
1 files changed, 74 insertions, 10 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index 850d205..913eafd 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -133,6 +133,7 @@
" _target_: text_recognizer.models.vqvae.VQVAELitModel\n",
" interval: step\n",
" monitor: val/loss\n",
+ " latent_loss_weight: 0.25\n",
"network:\n",
" _target_: text_recognizer.networks.vqvae.VQVAE\n",
" in_channels: 1\n",
@@ -174,7 +175,7 @@
"print_config: true\n",
"ignore_warnings: true\n",
"\n",
- "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss'}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n"
+ "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss', 'latent_loss_weight': 0.25}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n"
]
}
],
@@ -196,7 +197,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2021-08-04 04:49:04.188 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
+ "2021-08-04 05:07:26.480 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
]
}
],
@@ -206,7 +207,7 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 5,
"id": "969ba3be-d78f-4b1e-b522-ea8a42669e86",
"metadata": {},
"outputs": [],
@@ -216,7 +217,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 6,
"id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c",
"metadata": {},
"outputs": [],
@@ -226,7 +227,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 7,
"id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282",
"metadata": {},
"outputs": [
@@ -236,7 +237,7 @@
"torch.Size([1, 64, 144, 160])"
]
},
- "execution_count": 37,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -247,7 +248,7 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 8,
"id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0",
"metadata": {},
"outputs": [
@@ -260,20 +261,83 @@
"torch.Size([512])\n",
"torch.Size([512])\n"
]
- },
+ }
+ ],
+ "source": [
+ "t, l = network(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "9a9450d2-f45d-4823-adac-68a8ea05ed1d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.0188, grad_fn=<AddBackward0>)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "l"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "93b8c90f-788a-4095-aa7a-55b34f0ddaaf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch.nn import functional as F\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "c9983788-2dae-4375-a821-a64cd1c68edf",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.5669, grad_fn=<AddBackward0>)"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "F.mse_loss(x, t) + l"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "29b128ca-80b7-481e-bb3c-44f109c7d292",
+ "metadata": {},
+ "outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1, 576, 640])"
]
},
- "execution_count": 38,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "network(x)[0].shape"
+ "t.shape"
]
},
{