diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 15:15:26 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 15:15:26 +0200 |
commit | 04c40f790e405ced6e6b90cf0a8aea268b9345c4 (patch) | |
tree | d5e05ee09fa99ee8d56d5373bde18626274a1fdd | |
parent | d3afa310f77f47553586eeee58e3d3345a754e2c (diff) |
Add char htr experiment, rename from ocr to htr, vqvae loss collapses
-rw-r--r-- | notebooks/05c-test-model-end-to-end.ipynb | 84 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 11 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/quantizer.py | 2 | ||||
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 2 | ||||
-rw-r--r-- | training/conf/callbacks/wandb_htr.yaml (renamed from training/conf/callbacks/wandb_ocr.yaml) | 0 | ||||
-rw-r--r-- | training/conf/callbacks/wandb_htr_predictions.yaml (renamed from training/conf/callbacks/wandb_ocr_predictions.yaml) | 0 | ||||
-rw-r--r-- | training/conf/config.yaml | 2 | ||||
-rw-r--r-- | training/conf/experiment/htr_char.yaml | 12 | ||||
-rw-r--r-- | training/conf/experiment/vqvae.yaml | 3 | ||||
-rw-r--r-- | training/conf/mapping/characters.yaml (renamed from training/conf/mapping/emnist.yaml) | 0 | ||||
-rw-r--r-- | training/conf/network/decoder/transformer_decoder.yaml | 2 | ||||
-rw-r--r-- | training/conf/trainer/default.yaml | 2 |
12 files changed, 96 insertions, 24 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index 850d205..913eafd 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -133,6 +133,7 @@ " _target_: text_recognizer.models.vqvae.VQVAELitModel\n", " interval: step\n", " monitor: val/loss\n", + " latent_loss_weight: 0.25\n", "network:\n", " _target_: text_recognizer.networks.vqvae.VQVAE\n", " in_channels: 1\n", @@ -174,7 +175,7 @@ "print_config: true\n", "ignore_warnings: true\n", "\n", - "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss'}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n" + "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss', 'latent_loss_weight': 0.25}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n" ] } ], @@ -196,7 +197,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-08-04 04:49:04.188 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" + "2021-08-04 05:07:26.480 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" ] } ], @@ -206,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 5, "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86", "metadata": {}, "outputs": [], @@ -216,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 6, "id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c", "metadata": {}, "outputs": [], @@ -226,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 7, "id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282", "metadata": {}, "outputs": [ @@ -236,7 +237,7 @@ "torch.Size([1, 64, 144, 160])" ] }, - "execution_count": 37, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -247,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 8, "id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0", "metadata": {}, "outputs": [ @@ -260,20 +261,83 @@ "torch.Size([512])\n", "torch.Size([512])\n" ] - }, + } + ], + "source": [ + "t, l = network(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9a9450d2-f45d-4823-adac-68a8ea05ed1d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.0188, grad_fn=<AddBackward0>)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "l" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "93b8c90f-788a-4095-aa7a-55b34f0ddaaf", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn import functional as F\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c9983788-2dae-4375-a821-a64cd1c68edf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.5669, grad_fn=<AddBackward0>)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "F.mse_loss(x, t) + l" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "29b128ca-80b7-481e-bb3c-44f109c7d292", + "metadata": {}, + "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 1, 576, 640])" ] }, - "execution_count": 38, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "network(x)[0].shape" + "t.shape" ] }, { diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 5890fd9..7f79b78 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -1,11 +1,8 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Any, Dict, Union, Tuple, Type +from typing import Tuple import attr -from omegaconf import DictConfig -from torch import nn from torch import Tensor -import wandb from text_recognizer.models.base import BaseLitModel @@ -25,7 +22,7 @@ class VQVAELitModel(BaseLitModel): data, _ = batch reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += self.latent_loss_weight * vq_loss + loss = loss + self.latent_loss_weight * vq_loss self.log("train/loss", loss) return loss @@ -34,7 +31,7 @@ class VQVAELitModel(BaseLitModel): data, _ = batch reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += self.latent_loss_weight * vq_loss + loss = loss + self.latent_loss_weight * vq_loss self.log("val/loss", loss, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -42,5 +39,5 @@ class VQVAELitModel(BaseLitModel): data, _ = batch reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += self.latent_loss_weight * vq_loss + loss = loss + self.latent_loss_weight * vq_loss self.log("test/loss", loss) diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py index 5e0b602..1b59e78 100644 --- a/text_recognizer/networks/vqvae/quantizer.py +++ b/text_recognizer/networks/vqvae/quantizer.py @@ -83,8 +83,6 @@ class VectorQuantizer(nn.Module): def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: batch_cluster_size = one_hot_encoding.sum(axis=0) batch_embedding_avg = (latent.t() @ one_hot_encoding).t() - print(batch_cluster_size.shape) - print(self.embedding._cluster_size.shape) self.embedding._cluster_size.data.mul_(self.decay).add_( batch_cluster_size, alpha=1 - self.decay ) diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index c750e4b..61d71df 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -119,7 +119,7 @@ class LogTextPredictions(Callback): ] experiment.log( - {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)} + {f"HTR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)} ) def on_sanity_check_start( diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_htr.yaml index 9c9a6da..9c9a6da 100644 --- a/training/conf/callbacks/wandb_ocr.yaml +++ b/training/conf/callbacks/wandb_htr.yaml diff --git a/training/conf/callbacks/wandb_ocr_predictions.yaml b/training/conf/callbacks/wandb_htr_predictions.yaml index 573fa96..573fa96 100644 --- a/training/conf/callbacks/wandb_ocr_predictions.yaml +++ b/training/conf/callbacks/wandb_htr_predictions.yaml diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 6b74502..c606366 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - callbacks: wandb_ocr + - callbacks: wandb_htr - criterion: label_smoothing - datamodule: iam_extended_paragraphs - hydra: default diff --git a/training/conf/experiment/htr_char.yaml b/training/conf/experiment/htr_char.yaml new file mode 100644 index 0000000..77126ae --- /dev/null +++ b/training/conf/experiment/htr_char.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - override /mapping: characters + +criterion: + ignore_index: 3 + +network: + num_classes: 89 + pad_index: 3 + max_output_len: 682 diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml index 13e5f34..699612e 100644 --- a/training/conf/experiment/vqvae.yaml +++ b/training/conf/experiment/vqvae.yaml @@ -8,6 +8,7 @@ defaults: trainer: max_epochs: 64 + gradient_clip_val: 0.25 datamodule: batch_size: 32 @@ -17,4 +18,4 @@ lr_scheduler: steps_per_epoch: 624 optimizer: - lr: 1.0e-2 + lr: 1.0e-3 diff --git a/training/conf/mapping/emnist.yaml b/training/conf/mapping/characters.yaml index 14e966b..14e966b 100644 --- a/training/conf/mapping/emnist.yaml +++ b/training/conf/mapping/characters.yaml diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index c326c04..bc0678b 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -4,7 +4,7 @@ defaults: _target_: text_recognizer.networks.transformer.Decoder dim: 128 depth: 2 -num_heads: 8 +num_heads: 4 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: dim_head: 64 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index c665adc..0fa9ce1 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -8,7 +8,7 @@ gpus: 1 precision: 16 max_epochs: 512 terminate_on_nan: true -weights_summary: top +weights_summary: full limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 |