summaryrefslogtreecommitdiff
path: root/training/conf
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf')
-rw-r--r--training/conf/callbacks/wandb_image_reconstructions.yaml3
-rw-r--r--training/conf/callbacks/wandb_vae.yaml6
-rw-r--r--training/conf/config.yaml2
-rw-r--r--training/conf/experiment/vqvae.yaml20
-rw-r--r--training/conf/experiment/vqvae_experiment.yaml13
-rw-r--r--training/conf/model/lit_vqvae.yaml4
-rw-r--r--training/conf/network/conv_transformer.yaml2
-rw-r--r--training/conf/network/decoder/transformer_decoder.yaml4
-rw-r--r--training/conf/network/vqvae.yaml21
9 files changed, 45 insertions, 30 deletions
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
index e69de29..6cc4ada 100644
--- a/training/conf/callbacks/wandb_image_reconstructions.yaml
+++ b/training/conf/callbacks/wandb_image_reconstructions.yaml
@@ -0,0 +1,3 @@
+log_image_reconstruction:
+ _target_: callbacks.wandb_callbacks.LogReconstuctedImages
+ num_samples: 8
diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml
new file mode 100644
index 0000000..609a8e8
--- /dev/null
+++ b/training/conf/callbacks/wandb_vae.yaml
@@ -0,0 +1,6 @@
+defaults:
+ - default
+ - wandb_watch
+ - wandb_code
+ - wandb_checkpoints
+ - wandb_image_reconstructions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 782bcbb..6b74502 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
defaults:
- callbacks: wandb_ocr
- criterion: label_smoothing
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
new file mode 100644
index 0000000..13e5f34
--- /dev/null
+++ b/training/conf/experiment/vqvae.yaml
@@ -0,0 +1,20 @@
+# @package _global_
+
+defaults:
+ - override /network: vqvae
+ - override /criterion: mse
+ - override /model: lit_vqvae
+ - override /callbacks: wandb_vae
+
+trainer:
+ max_epochs: 64
+
+datamodule:
+ batch_size: 32
+
+lr_scheduler:
+ epochs: 64
+ steps_per_epoch: 624
+
+optimizer:
+ lr: 1.0e-2
diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml
deleted file mode 100644
index 0858c3d..0000000
--- a/training/conf/experiment/vqvae_experiment.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-defaults:
- - override /network: vqvae
- - override /criterion: mse
- - override /optimizer: madgrad
- - override /lr_scheduler: one_cycle
- - override /model: lit_vqvae
- - override /dataset: iam_extended_paragraphs
- - override /trainer: default
- - override /callbacks:
- - wandb
-
-load_checkpoint: null
-logging: INFO
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index b337fe6..8837573 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,2 +1,4 @@
_target_: text_recognizer.models.vqvae.VQVAELitModel
-mapping: sentence_piece
+interval: step
+monitor: val/loss
+latent_loss_weight: 0.25
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f76e892..d3a3b0f 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -4,7 +4,7 @@ defaults:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
-hidden_dim: 96
+hidden_dim: 128
dropout_rate: 0.2
num_classes: 1006
pad_index: 1002
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index eb80f64..c326c04 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -2,12 +2,12 @@ defaults:
- rotary_emb: null
_target_: text_recognizer.networks.transformer.Decoder
-dim: 96
+dim: 128
depth: 2
num_heads: 8
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
- dim_head: 16
+ dim_head: 64
dropout_rate: 0.2
norm_fn: torch.nn.LayerNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 22eebf8..5a5c066 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -1,13 +1,8 @@
-type: VQVAE
-args:
- in_channels: 1
- channels: [64, 96]
- kernel_sizes: [4, 4]
- strides: [2, 2]
- num_residual_layers: 2
- embedding_dim: 64
- num_embeddings: 256
- upsampling: null
- beta: 0.25
- activation: leaky_relu
- dropout_rate: 0.2
+_target_: text_recognizer.networks.vqvae.VQVAE
+in_channels: 1
+res_channels: 32
+num_residual_layers: 2
+embedding_dim: 64
+num_embeddings: 512
+decay: 0.99
+activation: mish