From d3afa310f77f47553586eeee58e3d3345a754e2c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 4 Aug 2021 05:03:51 +0200 Subject: New VQVAE --- .../conf/callbacks/wandb_image_reconstructions.yaml | 3 +++ training/conf/callbacks/wandb_vae.yaml | 6 ++++++ training/conf/config.yaml | 2 ++ training/conf/experiment/vqvae.yaml | 20 ++++++++++++++++++++ training/conf/experiment/vqvae_experiment.yaml | 13 ------------- training/conf/model/lit_vqvae.yaml | 4 +++- training/conf/network/conv_transformer.yaml | 2 +- .../conf/network/decoder/transformer_decoder.yaml | 4 ++-- training/conf/network/vqvae.yaml | 21 ++++++++------------- 9 files changed, 45 insertions(+), 30 deletions(-) create mode 100644 training/conf/callbacks/wandb_vae.yaml create mode 100644 training/conf/experiment/vqvae.yaml delete mode 100644 training/conf/experiment/vqvae_experiment.yaml (limited to 'training/conf') diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml index e69de29..6cc4ada 100644 --- a/training/conf/callbacks/wandb_image_reconstructions.yaml +++ b/training/conf/callbacks/wandb_image_reconstructions.yaml @@ -0,0 +1,3 @@ +log_image_reconstruction: + _target_: callbacks.wandb_callbacks.LogReconstuctedImages + num_samples: 8 diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml new file mode 100644 index 0000000..609a8e8 --- /dev/null +++ b/training/conf/callbacks/wandb_vae.yaml @@ -0,0 +1,6 @@ +defaults: + - default + - wandb_watch + - wandb_code + - wandb_checkpoints + - wandb_image_reconstructions diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 782bcbb..6b74502 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,3 +1,5 @@ +# @package _global_ + defaults: - callbacks: wandb_ocr - criterion: label_smoothing diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml new file mode 100644 index 0000000..13e5f34 --- /dev/null +++ b/training/conf/experiment/vqvae.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +defaults: + - override /network: vqvae + - override /criterion: mse + - override /model: lit_vqvae + - override /callbacks: wandb_vae + +trainer: + max_epochs: 64 + +datamodule: + batch_size: 32 + +lr_scheduler: + epochs: 64 + steps_per_epoch: 624 + +optimizer: + lr: 1.0e-2 diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml deleted file mode 100644 index 0858c3d..0000000 --- a/training/conf/experiment/vqvae_experiment.yaml +++ /dev/null @@ -1,13 +0,0 @@ -defaults: - - override /network: vqvae - - override /criterion: mse - - override /optimizer: madgrad - - override /lr_scheduler: one_cycle - - override /model: lit_vqvae - - override /dataset: iam_extended_paragraphs - - override /trainer: default - - override /callbacks: - - wandb - -load_checkpoint: null -logging: INFO diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml index b337fe6..8837573 100644 --- a/training/conf/model/lit_vqvae.yaml +++ b/training/conf/model/lit_vqvae.yaml @@ -1,2 +1,4 @@ _target_: text_recognizer.models.vqvae.VQVAELitModel -mapping: sentence_piece +interval: step +monitor: val/loss +latent_loss_weight: 0.25 diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index f76e892..d3a3b0f 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -4,7 +4,7 @@ defaults: _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] -hidden_dim: 96 +hidden_dim: 128 dropout_rate: 0.2 num_classes: 1006 pad_index: 1002 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index eb80f64..c326c04 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -2,12 +2,12 @@ defaults: - rotary_emb: null _target_: text_recognizer.networks.transformer.Decoder -dim: 96 +dim: 128 depth: 2 num_heads: 8 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: - dim_head: 16 + dim_head: 64 dropout_rate: 0.2 norm_fn: torch.nn.LayerNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml index 22eebf8..5a5c066 100644 --- a/training/conf/network/vqvae.yaml +++ b/training/conf/network/vqvae.yaml @@ -1,13 +1,8 @@ -type: VQVAE -args: - in_channels: 1 - channels: [64, 96] - kernel_sizes: [4, 4] - strides: [2, 2] - num_residual_layers: 2 - embedding_dim: 64 - num_embeddings: 256 - upsampling: null - beta: 0.25 - activation: leaky_relu - dropout_rate: 0.2 +_target_: text_recognizer.networks.vqvae.VQVAE +in_channels: 1 +res_channels: 32 +num_residual_layers: 2 +embedding_dim: 64 +num_embeddings: 512 +decay: 0.99 +activation: mish -- cgit v1.2.3-70-g09d2