| -rw-r--r-- | training/conf/config.yaml | 10 |
| -rw-r--r-- | training/conf/datamodule/iam_extended_paragraphs.yaml | 1 |
| -rw-r--r-- | training/conf/experiment/htr_char.yaml | 17 |
| -rw-r--r-- | training/conf/experiment/vq_htr_char.yaml | 74 |
| -rw-r--r-- | training/conf/experiment/vqgan.yaml | 36 |
| -rw-r--r-- | training/conf/experiment/vqvae.yaml | 38 |
| -rw-r--r-- | training/conf/model/lit_transformer.yaml | 2 |
| -rw-r--r-- | training/conf/model/lit_vqvae.yaml | 1 |
| -rw-r--r-- | training/conf/network/conv_transformer.yaml | 1 |
| -rw-r--r-- | training/conf/network/decoder/vae_decoder.yaml | 5 |
| -rw-r--r-- | training/conf/network/encoder/vae_encoder.yaml | 5 |
| -rw-r--r-- | training/conf/network/vqvae.yaml | 6 |
| -rw-r--r-- | training/conf/optimizers/madgrad.yaml | 2 |
13 files changed, 147 insertions, 51 deletions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 5897d87..9ed366f 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -7,8 +7,8 @@ defaults:
   - hydra: default
   - logger: wandb
   - lr_schedulers:
-      - one_cycle
-  - mapping: word_piece
+      - cosine_annealing
+  - mapping: characters # word_piece
   - model: lit_transformer
   - network: conv_transformer
   - optimizers:
@@ -21,6 +21,12 @@ train: true
 test: true
 logging: INFO
 
+# datamodule:
+#   word_pieces: false
+
+# model:
+#   max_output_len: 682
+
 # path to original working directory
 # hydra hijacks working directory by changing it to the current log directory,
 # so it's useful to have this path as a special variable
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index a2dd293..a0ffe56 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -5,3 +5,4 @@ train_fraction: 0.8
 augment: true
 pin_memory: false
 word_pieces: true
+resize: null
diff --git a/training/conf/experiment/htr_char.yaml b/training/conf/experiment/htr_char.yaml
deleted file mode 100644
index e51a116..0000000
--- a/training/conf/experiment/htr_char.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-# @package _global_
-
-defaults:
-  - override /mapping: characters
-
-datamodule:
-  word_pieces: false
-
-criterion:
-  ignore_index: 3
-
-network:
-  num_classes: 58
-  pad_index: 3
-
-model:
-  max_output_len: 682
diff --git a/training/conf/experiment/vq_htr_char.yaml b/training/conf/experiment/vq_htr_char.yaml
new file mode 100644
index 0000000..b34dd11
--- /dev/null
+++ b/training/conf/experiment/vq_htr_char.yaml
@@ -0,0 +1,74 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /network: null
+  - override /model: null
+
+mapping:
+  _target_: text_recognizer.data.emnist_mapping.EmnistMapping
+  extra_symbols: [ "\n" ]
+
+datamodule:
+  word_pieces: false
+  batch_size: 8
+
+criterion:
+  ignore_index: 3
+
+network:
+  _target_: text_recognizer.networks.vq_transformer.VqTransformer
+  input_dims: [1, 576, 640]
+  encoder_dim: 64
+  hidden_dim: 64
+  dropout_rate: 0.1
+  num_classes: 58
+  pad_index: 3
+  no_grad: false
+  encoder:
+    _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
+    hidden_dim: 128
+    embedding_dim: 64
+    num_embeddings: 1024
+    decay: 0.99
+    encoder:
+      _target_: text_recognizer.networks.vqvae.encoder.Encoder
+      in_channels: 1
+      hidden_dim: 64
+      channels_multipliers: [1, 1, 2, 2]
+      dropout_rate: 0.0
+    decoder:
+      _target_: text_recognizer.networks.vqvae.decoder.Decoder
+      out_channels: 1
+      hidden_dim: 64
+      channels_multipliers: [2, 2, 1, 1]
+      dropout_rate: 0.0
+  decoder:
+    _target_: text_recognizer.networks.transformer.Decoder
+    dim: 64
+    depth: 2
+    num_heads: 4
+    attn_fn: text_recognizer.networks.transformer.attention.Attention
+    attn_kwargs:
+      dim_head: 32
+      dropout_rate: 0.2
+    norm_fn: torch.nn.LayerNorm
+    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+    ff_kwargs:
+      dim_out: null
+      expansion_factor: 4
+      glu: true
+      dropout_rate: 0.2
+    cross_attend: true
+    pre_norm: true
+    rotary_emb: null
+
+  # pretrained_encoder_path: "training/logs/runs/2021-09-13/08-35-57/checkpoints/epoch=98.ckpt"
+
+model:
+  _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+  max_output_len: 682
+  # max_output_len: 451
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 9224bc7..6c78deb 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -2,7 +2,7 @@
 
 defaults:
   - override /network: vqvae
-  - override /criterion: vqgan_loss
+  - override /criterion: null
   - override /model: lit_vqgan
   - override /callbacks: wandb_vae
   - override /optimizers: null
@@ -11,7 +11,7 @@ defaults:
 criterion:
   _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
   reconstruction_loss:
-    _target_: torch.nn.L1Loss
+    _target_: torch.nn.MSELoss
     reduction: mean
   discriminator:
     _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
@@ -21,35 +21,41 @@ criterion:
   vq_loss_weight: 0.25
   discriminator_weight: 1.0
   discriminator_factor: 1.0
-  discriminator_iter_start: 2.0e4
+  discriminator_iter_start: 5e2
 
 datamodule:
-  batch_size: 6
+  batch_size: 8
+  resize: [288, 320]
 
-lr_schedulers: null
+lr_schedulers:
+  generator:
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: 128
+    eta_min: 4.5e-6
+    last_epoch: -1
 
-# lr_schedulers:
-#   generator:
+    interval: epoch
+    monitor: val/loss
 #     _target_: torch.optim.lr_scheduler.OneCycleLR
 #     max_lr: 3.0e-4
 #     total_steps: null
 #     epochs: 100
-#     steps_per_epoch: 3369
+#     steps_per_epoch: 2496
 #     pct_start: 0.1
 #     anneal_strategy: cos
 #     cycle_momentum: true
 #     base_momentum: 0.85
 #     max_momentum: 0.95
-#     div_factor: 1.0e3
+#     div_factor: 25
 #     final_div_factor: 1.0e4
 #     three_phase: true
 #     last_epoch: -1
 #     verbose: false
-#
+
 #     # Non-class arguments
 #     interval: step
 #     monitor: val/loss
-#
+
 #   discriminator:
 #     _target_: torch.optim.lr_scheduler.CosineAnnealingLR
 #     T_max: 64
@@ -79,7 +85,7 @@ optimizers:
     parameters: loss_fn.discriminator
 
 trainer:
-  max_epochs: 64
-  # gradient_clip_val: 1.0e1
-
-summary: null
+  max_epochs: 128
+  limit_train_batches: 0.1
+  limit_val_batches: 0.1
+  # gradient_clip_val: 100
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index d3db471..d9fa2c4 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -2,26 +2,52 @@
 
 defaults:
   - override /network: vqvae
-  - override /criterion: mae
+  - override /criterion: mse
   - override /model: lit_vqvae
   - override /callbacks: wandb_vae
-  - override /lr_schedulers:
-      - cosine_annealing
+  - override /optimizers: null
+  # - override /lr_schedulers:
+      # - cosine_annealing
+
+# lr_schedulers: null
+#   network:
+#     _target_: torch.optim.lr_scheduler.OneCycleLR
+#     max_lr: 1.0e-2
+#     total_steps: null
+#     epochs: 100
+#     steps_per_epoch: 200
+#     pct_start: 0.1
+#     anneal_strategy: cos
+#     cycle_momentum: true
+#     base_momentum: 0.85
+#     max_momentum: 0.95
+#     div_factor: 25
+#     final_div_factor: 1.0e4
+#     three_phase: true
+#     last_epoch: -1
+#     verbose: false
+
+#     # Non-class arguments
+#     interval: step
+#     monitor: val/loss
 
 optimizers:
   network:
     _target_: madgrad.MADGRAD
-    lr: 3.0e-4
+    lr: 1.0e-4
     momentum: 0.9
     weight_decay: 0
-    eps: 1.0e-6
+    eps: 1.0e-7
     parameters: network
 
 trainer:
-  max_epochs: 256
+  max_epochs: 128
+  limit_train_batches: 0.01
+  limit_val_batches: 0.1
 
 datamodule:
   batch_size: 8
+  # resize: [288, 320]
 
 summary: null
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index c190151..0ec3b8a 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,6 +1,4 @@
 _target_: text_recognizer.models.transformer.TransformerLitModel
-interval: step
-monitor: val/loss
 max_output_len: 451
 start_token: <s>
 end_token: <e>
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 632668b..6dc44d7 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,2 +1 @@
 _target_: text_recognizer.models.vqvae.VQVAELitModel
-latent_loss_weight: 0.25
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index d3a3b0f..1d61129 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -5,6 +5,7 @@ defaults:
 _target_: text_recognizer.networks.conv_transformer.ConvTransformer
 input_dims: [1, 576, 640]
 hidden_dim: 128
+encoder_dim: 1280
 dropout_rate: 0.2
 num_classes: 1006
 pad_index: 1002
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index 60cdcf1..2053544 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,5 +1,6 @@
 _target_: text_recognizer.networks.vqvae.decoder.Decoder
 out_channels: 1
-hidden_dim: 64
-channels_multipliers: [8, 4, 2, 1]
+hidden_dim: 32
+channels_multipliers: [4, 2, 1]
 dropout_rate: 0.0
+activation: leaky_relu
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 73529fc..2ea3adf 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,6 @@
 _target_: text_recognizer.networks.vqvae.encoder.Encoder
 in_channels: 1
-hidden_dim: 64
-channels_multipliers: [1, 2, 4, 8]
+hidden_dim: 32
+channels_multipliers: [1, 2, 4]
 dropout_rate: 0.0
+activation: leaky_relu
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 70d27d7..d97e9b6 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -3,7 +3,7 @@ defaults:
   - decoder: vae_decoder
 
 _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 512
-embedding_dim: 64
-num_embeddings: 4096
+hidden_dim: 128
+embedding_dim: 32
+num_embeddings: 1024
 decay: 0.99
diff --git a/training/conf/optimizers/madgrad.yaml b/training/conf/optimizers/madgrad.yaml
index a6c059d..d97bf7e 100644
--- a/training/conf/optimizers/madgrad.yaml
+++ b/training/conf/optimizers/madgrad.yaml
@@ -1,6 +1,6 @@
 madgrad:
   _target_: madgrad.MADGRAD
-  lr: 1.0e-3
+  lr: 3.0e-4
   momentum: 0.9
   weight_decay: 0
   eps: 1.0e-6