diff options
Diffstat (limited to 'training/conf')
-rw-r--r-- | training/conf/callbacks/wandb_image_reconstructions.yaml | 3 | ||||
-rw-r--r-- | training/conf/callbacks/wandb_vae.yaml | 6 | ||||
-rw-r--r-- | training/conf/config.yaml | 2 | ||||
-rw-r--r-- | training/conf/experiment/vqvae.yaml | 20 | ||||
-rw-r--r-- | training/conf/experiment/vqvae_experiment.yaml | 13 | ||||
-rw-r--r-- | training/conf/model/lit_vqvae.yaml | 4 | ||||
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 2 | ||||
-rw-r--r-- | training/conf/network/decoder/transformer_decoder.yaml | 4 | ||||
-rw-r--r-- | training/conf/network/vqvae.yaml | 21 |
9 files changed, 45 insertions, 30 deletions
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml index e69de29..6cc4ada 100644 --- a/training/conf/callbacks/wandb_image_reconstructions.yaml +++ b/training/conf/callbacks/wandb_image_reconstructions.yaml @@ -0,0 +1,3 @@ +log_image_reconstruction: + _target_: callbacks.wandb_callbacks.LogReconstuctedImages + num_samples: 8 diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml new file mode 100644 index 0000000..609a8e8 --- /dev/null +++ b/training/conf/callbacks/wandb_vae.yaml @@ -0,0 +1,6 @@ +defaults: + - default + - wandb_watch + - wandb_code + - wandb_checkpoints + - wandb_image_reconstructions diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 782bcbb..6b74502 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,3 +1,5 @@ +# @package _global_ + defaults: - callbacks: wandb_ocr - criterion: label_smoothing diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml new file mode 100644 index 0000000..13e5f34 --- /dev/null +++ b/training/conf/experiment/vqvae.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +defaults: + - override /network: vqvae + - override /criterion: mse + - override /model: lit_vqvae + - override /callbacks: wandb_vae + +trainer: + max_epochs: 64 + +datamodule: + batch_size: 32 + +lr_scheduler: + epochs: 64 + steps_per_epoch: 624 + +optimizer: + lr: 1.0e-2 diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml deleted file mode 100644 index 0858c3d..0000000 --- a/training/conf/experiment/vqvae_experiment.yaml +++ /dev/null @@ -1,13 +0,0 @@ -defaults: - - override /network: vqvae - - override /criterion: mse - - override /optimizer: madgrad - - override /lr_scheduler: one_cycle - - override /model: lit_vqvae - - override /dataset: iam_extended_paragraphs - - override /trainer: default - - override /callbacks: - - wandb - -load_checkpoint: null -logging: INFO diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml index b337fe6..8837573 100644 --- a/training/conf/model/lit_vqvae.yaml +++ b/training/conf/model/lit_vqvae.yaml @@ -1,2 +1,4 @@ _target_: text_recognizer.models.vqvae.VQVAELitModel -mapping: sentence_piece +interval: step +monitor: val/loss +latent_loss_weight: 0.25 diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index f76e892..d3a3b0f 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -4,7 +4,7 @@ defaults: _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] -hidden_dim: 96 +hidden_dim: 128 dropout_rate: 0.2 num_classes: 1006 pad_index: 1002 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index eb80f64..c326c04 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -2,12 +2,12 @@ defaults: - rotary_emb: null _target_: text_recognizer.networks.transformer.Decoder -dim: 96 +dim: 128 depth: 2 num_heads: 8 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: - dim_head: 16 + dim_head: 64 dropout_rate: 0.2 norm_fn: torch.nn.LayerNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml index 22eebf8..5a5c066 100644 --- a/training/conf/network/vqvae.yaml +++ b/training/conf/network/vqvae.yaml @@ -1,13 +1,8 @@ -type: VQVAE -args: - in_channels: 1 - channels: [64, 96] - kernel_sizes: [4, 4] - strides: [2, 2] - num_residual_layers: 2 - embedding_dim: 64 - num_embeddings: 256 - upsampling: null - beta: 0.25 - activation: leaky_relu - dropout_rate: 0.2 +_target_: text_recognizer.networks.vqvae.VQVAE +in_channels: 1 +res_channels: 32 +num_residual_layers: 2 +embedding_dim: 64 +num_embeddings: 512 +decay: 0.99 +activation: mish |