-rw-r--r--  training/conf/experiment/conv_transformer_lines.yaml      | 26
-rw-r--r--  training/conf/experiment/conv_transformer_paragraphs.yaml | 51
-rw-r--r--  training/conf/experiment/vqgan.yaml                       | 89
-rw-r--r--  training/conf/network/vqvae.yaml                          | 17
4 files changed, 92 insertions, 91 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 11646ca..20e369e 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -10,7 +10,7 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-epochs: &epochs 620
+epochs: &epochs 300
 ignore_index: &ignore_index 3
 num_classes: &num_classes 57
 max_output_len: &max_output_len 89
@@ -27,7 +27,7 @@ callbacks:
   stochastic_weight_averaging:
     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
     swa_epoch_start: 0.75
-    swa_lrs: 1.0e-5
+    swa_lrs: 1.0e-4
     annealing_epochs: 10
     annealing_strategy: cos
     device: null
@@ -43,15 +43,15 @@ optimizers:
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-    T_max: *epochs
-    eta_min: 1.0e-5
-    last_epoch: -1
-    interval: epoch
-    monitor: val/loss
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: *epochs
+    eta_min: 1.0e-4
+    last_epoch: -1
+    interval: epoch
+    monitor: val/loss
 
 datamodule:
-  batch_size: 32
+  batch_size: 16
   num_workers: 12
   train_fraction: 0.9
   pin_memory: true
@@ -64,7 +64,7 @@ rotary_embedding: &rotary_embedding
 
 attn: &attn
   dim: &hidden_dim 256
-  num_heads: 4
+  num_heads: 6
   dim_head: 64
   dropout_rate: &dropout_rate 0.5
@@ -76,12 +76,12 @@ network:
   pad_index: *ignore_index
   encoder:
     _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
-    arch: b3
+    arch: b0
     stochastic_dropout_rate: 0.2
     bn_momentum: 0.99
     bn_eps: 1.0e-3
   decoder:
-    depth: 6
+    depth: 3
     _target_: text_recognizer.networks.transformer.layers.Decoder
     self_attn:
       _target_: text_recognizer.networks.transformer.attention.Attention
@@ -106,7 +106,7 @@ network:
   pixel_pos_embedding:
     _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
     dim: *hidden_dim
-    shape: [1, 32]
+    shape: [3, 64]
 
 model:
   _target_: text_recognizer.models.transformer.TransformerLitModel
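The scheduler block above maps onto a plain PyTorch `CosineAnnealingLR`, with `interval: epoch` telling Lightning to step it once per epoch. A minimal sketch of what the updated values request; the optimizer and its learning rate here are stand-ins (the experiments use `madgrad.MADGRAD`), while `T_max` and `eta_min` are taken from the diff:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(8, 8)  # stand-in for the ConvTransformer
optimizer = torch.optim.SGD(model.parameters(), lr=3.0e-4)  # illustrative lr
scheduler = CosineAnnealingLR(optimizer, T_max=300, eta_min=1.0e-4, last_epoch=-1)

for epoch in range(300):
    ...  # one training epoch
    scheduler.step()  # `interval: epoch` -> stepped once per epoch
```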
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 00ad389..d2916e1 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -10,8 +10,7 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-
-epochs: &epochs 720
+epochs: &epochs 512
 ignore_index: &ignore_index 3
 num_classes: &num_classes 58
 max_output_len: &max_output_len 682
@@ -29,7 +28,7 @@ callbacks:
   stochastic_weight_averaging:
     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
     swa_epoch_start: 0.75
-    swa_lrs: 1.0e-5
+    swa_lrs: 3.0e-5
     annealing_epochs: 10
     annealing_strategy: cos
     device: null
@@ -37,7 +36,7 @@ callbacks:
 
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 1.0e-4
+    lr: 3.0e-4
     momentum: 0.9
     weight_decay: 5.0e-6
     eps: 1.0e-6
@@ -45,27 +44,16 @@ optimizers:
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 1.0e-4
-    total_steps: null
-    epochs: *epochs
-    steps_per_epoch: 1264
-    pct_start: 0.01
-    anneal_strategy: cos
-    cycle_momentum: true
-    base_momentum: 0.85
-    max_momentum: 0.95
-    div_factor: 25
-    final_div_factor: 1.0e2
-    three_phase: false
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: *epochs
+    eta_min: 1.0e-5
     last_epoch: -1
-    verbose: false
-    interval: step
+    interval: epoch
     monitor: val/loss
 
 datamodule:
   _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-  batch_size: 6
+  batch_size: 4
   num_workers: 12
   train_fraction: 0.8
   pin_memory: true
@@ -77,27 +65,25 @@ rotary_embedding: &rotary_embedding
   dim: 64
 
 attn: &attn
-  dim: 192
+  dim: &hidden_dim 192
   num_heads: 4
   dim_head: 64
-  dropout_rate: 0.05
+  dropout_rate: &dropout_rate 0.5
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 576, 640]
-  hidden_dim: &hidden_dim 192
+  hidden_dim: *hidden_dim
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
     _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
-    arch: b0
-    out_channels: 1280
+    arch: b1
     stochastic_dropout_rate: 0.2
     bn_momentum: 0.99
     bn_eps: 1.0e-3
   decoder:
-    depth: 3
-    local_depth: 2
+    depth: 6
     _target_: text_recognizer.networks.transformer.layers.Decoder
     self_attn:
       _target_: text_recognizer.networks.transformer.attention.Attention
@@ -108,13 +94,6 @@ network:
       _target_: text_recognizer.networks.transformer.attention.Attention
       << : *attn
      causal: false
-    local_self_attn:
-      _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
-      << : *attn
-      window_size: 31
-      look_back: 1
-      autopad: true
-      << : *rotary_embedding
     norm:
       _target_: text_recognizer.networks.transformer.norm.ScaleNorm
       normalized_shape: *hidden_dim
@@ -124,7 +103,7 @@ network:
       dim_out: null
       expansion_factor: 4
       glu: true
-      dropout_rate: 0.05
+      dropout_rate: *dropout_rate
       pre_norm: true
     pixel_pos_embedding:
       _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
@@ -155,5 +134,5 @@ trainer:
   limit_val_batches: 1.0
   limit_test_batches: 1.0
   resume_from_checkpoint: null
-  accumulate_grad_batches: 4
+  accumulate_grad_batches: 2
   overfit_batches: 0
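A recurring move in this diff is converting literal values into YAML anchors (`dim: &hidden_dim 192`, `dropout_rate: &dropout_rate 0.5`) so aliases (`*hidden_dim`) and merge keys (`<< : *attn`) keep the attention, feed-forward, and norm blocks in sync. A minimal PyYAML sketch of that mechanism; the values mirror the config, but the consuming `self_attn`/`ff` sections are simplified for illustration:

```python
import yaml

# `&attn` names a block once; `<< : *attn` merges it into another mapping;
# `*dropout_rate` reuses a scalar defined elsewhere.
snippet = """
attn: &attn
  dim: &hidden_dim 192
  num_heads: 4
  dim_head: 64
  dropout_rate: &dropout_rate 0.5

self_attn:
  << : *attn
  causal: true

ff:
  dropout_rate: *dropout_rate
"""

cfg = yaml.safe_load(snippet)
assert cfg["self_attn"]["dim"] == 192      # merged in from the &attn anchor
assert cfg["self_attn"]["causal"] is True  # merged blocks can still add keys
assert cfg["ff"]["dropout_rate"] == 0.5    # scalar alias stays in sync
```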
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 98f3346..726757f 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -1,18 +1,26 @@
+# @package _global_
+
 defaults:
   - override /network: vqvae
   - override /criterion: null
   - override /model: lit_vqgan
-  - override /callbacks: wandb_vae
+  - override /callbacks: vae
   - override /optimizers: null
   - override /lr_schedulers: null
 
+epochs: &epochs 100
+ignore_index: &ignore_index 3
+num_classes: &num_classes 58
+max_output_len: &max_output_len 682
+summary: [[1, 1, 576, 640]]
+
 criterion:
-  _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
+  _target_: text_recognizer.criterion.vqgan_loss.VQGANLoss
   reconstruction_loss:
     _target_: torch.nn.BCEWithLogitsLoss
     reduction: mean
   discriminator:
-    _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
+    _target_: text_recognizer.criterion.n_layer_discriminator.NLayerDiscriminator
     in_channels: 1
     num_channels: 64
     num_layers: 3
@@ -21,39 +29,35 @@ criterion:
   discriminator_factor: 1.0
   discriminator_iter_start: 8.0e4
 
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.emnist.EmnistMapping
+    extra_symbols: [ "\n" ]
+
 datamodule:
-  batch_size: 12
-  # resize: [288, 320]
-  augment: false
+  _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
+  batch_size: 4
+  num_workers: 12
+  train_fraction: 0.9
+  pin_memory: true
+  << : *mapping
 
 lr_schedulers:
-  generator:
-    _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 3.0e-4
-    total_steps: null
-    epochs: 64
-    steps_per_epoch: 1685
-    pct_start: 0.3
-    anneal_strategy: cos
-    cycle_momentum: true
-    base_momentum: 0.85
-    max_momentum: 0.95
-    div_factor: 25.0
-    final_div_factor: 10000.0
-    three_phase: true
+  network:
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: *epochs
+    eta_min: 1.0e-5
     last_epoch: -1
-    verbose: false
-    interval: step
+    interval: epoch
     monitor: val/loss
 
-# discriminator:
-#   _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-#   T_max: 64
-#   eta_min: 0.0
-#   last_epoch: -1
-#
-#   interval: epoch
-#   monitor: val/loss
+  discriminator:
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: *epochs
+    eta_min: 1.0e-5
+    last_epoch: -1
+    interval: epoch
+    monitor: val/loss
 
 optimizers:
   generator:
@@ -75,11 +79,20 @@ optimizers:
     parameters: loss_fn.discriminator
 
 trainer:
-  max_epochs: 64
-  # limit_train_batches: 0.1
-  # limit_val_batches: 0.1
-  # gradient_clip_val: 100
-
-# tune: false
-# train: true
-# test: false
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: false
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 2
+  overfit_batches: 0
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 8210f04..22f786f 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -3,7 +3,16 @@ defaults:
   - decoder: vae_decoder
 
 _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 128
-embedding_dim: 32
-num_embeddings: 8192
-decay: 0.99
+quantizer:
+  _target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
+  input_dim: 128
+  codebook:
+    _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
+    dim: 8
+    codebook_size: 512
+    kmeans_init: true
+    kmeans_iters: 10
+    decay: 0.8
+    eps: 1.0e-5
+    threshold_dead: 2
+  commitment: 1.0
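The new `quantizer` block replaces the VQVAE's inline embedding parameters with an explicit vector quantizer built around a cosine-similarity codebook. A minimal sketch of the lookup that configuration implies (`dim: 8`, `codebook_size: 512`, commitment loss weighted by `commitment: 1.0`); function and variable names here are illustrative, not the repo's API, and the kmeans init, EMA codebook update (`decay: 0.8`), and dead-code replacement (`threshold_dead: 2`) are omitted:

```python
import torch
import torch.nn.functional as F

def cosine_quantize(latents: torch.Tensor, codebook: torch.Tensor):
    """latents: (n, dim); codebook: (codebook_size, dim)."""
    # Nearest code by cosine similarity (hence the L2 normalization).
    sim = F.normalize(latents, dim=-1) @ F.normalize(codebook, dim=-1).t()
    indices = sim.argmax(dim=-1)
    quantized = codebook[indices]
    # Straight-through estimator: the encoder receives gradients as if
    # quantization were the identity function.
    quantized = latents + (quantized - latents).detach()
    return quantized, indices

latents = torch.randn(16, 8, requires_grad=True)   # encoder output, dim: 8
codebook = torch.randn(512, 8)                      # codebook_size: 512
quantized, indices = cosine_quantize(latents, codebook)
commitment_loss = 1.0 * F.mse_loss(latents, quantized.detach())  # commitment: 1.0
```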