summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:05:51 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:05:51 +0200
commit53dce464914b2ce3817a7cec9441f7cfa5048248 (patch)
tree18688185b9ee158a7f80aef9d17bb59221ca07ee /training
parent6dd33e5a087159dcbabb845f167279778b2a8ea5 (diff)
Update to experiments configs
Diffstat (limited to 'training')
-rw-r--r--training/conf/experiment/vq_htr_char.yaml74
-rw-r--r--training/conf/experiment/vqgan.yaml33
-rw-r--r--training/conf/experiment/vqvae.yaml2
-rw-r--r--training/conf/network/decoder/vae_decoder.yaml5
-rw-r--r--training/conf/network/encoder/vae_encoder.yaml5
5 files changed, 14 insertions, 105 deletions
diff --git a/training/conf/experiment/vq_htr_char.yaml b/training/conf/experiment/vq_htr_char.yaml
deleted file mode 100644
index b34dd11..0000000
--- a/training/conf/experiment/vq_htr_char.yaml
+++ /dev/null
@@ -1,74 +0,0 @@
-# @package _global_
-
-defaults:
- - override /mapping: null
- - override /network: null
- - override /model: null
-
-mapping:
- _target_: text_recognizer.data.emnist_mapping.EmnistMapping
- extra_symbols: [ "\n" ]
-
-datamodule:
- word_pieces: false
- batch_size: 8
-
-criterion:
- ignore_index: 3
-
-network:
- _target_: text_recognizer.networks.vq_transformer.VqTransformer
- input_dims: [1, 576, 640]
- encoder_dim: 64
- hidden_dim: 64
- dropout_rate: 0.1
- num_classes: 58
- pad_index: 3
- no_grad: false
- encoder:
- _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
- hidden_dim: 128
- embedding_dim: 64
- num_embeddings: 1024
- decay: 0.99
- encoder:
- _target_: text_recognizer.networks.vqvae.encoder.Encoder
- in_channels: 1
- hidden_dim: 64
- channels_multipliers: [1, 1, 2, 2]
- dropout_rate: 0.0
- decoder:
- _target_: text_recognizer.networks.vqvae.decoder.Decoder
- out_channels: 1
- hidden_dim: 64
- channels_multipliers: [2, 2, 1, 1]
- dropout_rate: 0.0
- decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- dim: 64
- depth: 2
- num_heads: 4
- attn_fn: text_recognizer.networks.transformer.attention.Attention
- attn_kwargs:
- dim_head: 32
- dropout_rate: 0.2
- norm_fn: torch.nn.LayerNorm
- ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
- ff_kwargs:
- dim_out: null
- expansion_factor: 4
- glu: true
- dropout_rate: 0.2
- cross_attend: true
- pre_norm: true
- rotary_emb: null
-
- # pretrained_encoder_path: "training/logs/runs/2021-09-13/08-35-57/checkpoints/epoch=98.ckpt"
-
-model:
- _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
- start_token: <s>
- end_token: <e>
- pad_token: <p>
- max_output_len: 682
- # max_output_len: 451
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 6c78deb..34886ec 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -18,14 +18,14 @@ criterion:
in_channels: 1
num_channels: 64
num_layers: 3
- vq_loss_weight: 0.25
- discriminator_weight: 1.0
+ vq_loss_weight: 1.0
+ discriminator_weight: 0.8
discriminator_factor: 1.0
- discriminator_iter_start: 5e2
+ discriminator_iter_start: 7e4
datamodule:
batch_size: 8
- resize: [288, 320]
+ # resize: [288, 320]
lr_schedulers:
generator:
@@ -36,25 +36,6 @@ lr_schedulers:
interval: epoch
monitor: val/loss
-# _target_: torch.optim.lr_scheduler.OneCycleLR
-# max_lr: 3.0e-4
-# total_steps: null
-# epochs: 100
-# steps_per_epoch: 2496
-# pct_start: 0.1
-# anneal_strategy: cos
-# cycle_momentum: true
-# base_momentum: 0.85
-# max_momentum: 0.95
-# div_factor: 25
-# final_div_factor: 1.0e4
-# three_phase: true
-# last_epoch: -1
-# verbose: false
-
-# # Non-class arguments
-# interval: step
-# monitor: val/loss
# discriminator:
# _target_: torch.optim.lr_scheduler.CosineAnnealingLR
@@ -86,6 +67,6 @@ optimizers:
trainer:
max_epochs: 128
- limit_train_batches: 0.1
- limit_val_batches: 0.1
- # gradient_clip_val: 100
+ # limit_train_batches: 0.1
+ # limit_val_batches: 0.1
+ gradient_clip_val: 100
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index d9fa2c4..6e42690 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -43,7 +43,7 @@ optimizers:
trainer:
max_epochs: 128
- limit_train_batches: 0.01
+ limit_train_batches: 0.1
limit_val_batches: 0.1
datamodule:
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index 2053544..8b5502d 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,6 +1,7 @@
_target_: text_recognizer.networks.vqvae.decoder.Decoder
out_channels: 1
hidden_dim: 32
-channels_multipliers: [4, 2, 1]
+channels_multipliers: [4, 4, 2, 1]
dropout_rate: 0.0
-activation: leaky_relu
+activation: mish
+use_norm: true
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 2ea3adf..33ae0b9 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,6 +1,7 @@
_target_: text_recognizer.networks.vqvae.encoder.Encoder
in_channels: 1
hidden_dim: 32
-channels_multipliers: [1, 2, 4]
+channels_multipliers: [1, 2, 4, 4]
dropout_rate: 0.0
-activation: leaky_relu
+activation: mish
+use_norm: true