Diffstat (limited to 'training')
-rw-r--r-- | training/experiments/image_transformer.yaml | 70
-rw-r--r-- | training/run_experiment.py | 6
2 files changed, 73 insertions, 3 deletions
diff --git a/training/experiments/image_transformer.yaml b/training/experiments/image_transformer.yaml
new file mode 100644
index 0000000..7f0bbb7
--- /dev/null
+++ b/training/experiments/image_transformer.yaml
@@ -0,0 +1,70 @@
+network:
+  type: ImageTransformer
+  args:
+    input_shape: None
+    output_shape: None
+    encoder:
+      type: None
+      args: None
+    mapping: sentence_piece
+    num_decoder_layers: 4
+    hidden_dim: 256
+    num_heads: 4
+    expansion_dim: 1024
+    dropout_rate: 0.1
+    transformer_activation: glu
+
+model:
+  type: LitTransformerModel
+  args:
+    optimizer:
+      type: MADGRAD
+      args:
+        lr: 1.0e-2
+        momentum: 0.9
+        weight_decay: 0
+        eps: 1.0e-6
+    lr_scheduler:
+      type: CosineAnnealingLR
+      args:
+        T_max: 512
+    criterion:
+      type: CrossEntropyLoss
+      args:
+        weight: None
+        ignore_index: -100
+        reduction: mean
+
+    monitor: val_loss
+    mapping: sentence_piece
+
+data:
+  type: IAMExtendedParagraphs
+  args:
+    batch_size: 16
+    num_workers: 12
+    train_fraction: 0.8
+    augment: true
+
+callbacks:
+  - type: ModelCheckpoint
+    args:
+      monitor: val_loss
+      mode: min
+  - type: EarlyStopping
+    args:
+      monitor: val_loss
+      mode: min
+      patience: 10
+
+trainer:
+  args:
+    stochastic_weight_avg: true
+    auto_scale_batch_size: power
+    gradient_clip_val: 0
+    fast_dev_run: false
+    gpus: 1
+    precision: 16
+    max_epochs: 512
+    terminate_on_nan: true
+    weights_summary: true
diff --git a/training/run_experiment.py b/training/run_experiment.py
index ff8b886..8a29555 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -2,7 +2,7 @@
 from datetime import datetime
 import importlib
 from pathlib import Path
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, NamedTuple, Optional, Union, Type
 
 import click
 from loguru import logger
@@ -50,10 +50,10 @@ def _import_class(module_and_class_name: str) -> type:
     return getattr(module, class_name)
 
 
-def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]:
+def _configure_pl_callbacks(args: List[Union[OmegaConf, NamedTuple]]) -> List[Type[pl.callbacks.Callback]]:
     """Configures PyTorch Lightning callbacks."""
     pl_callbacks = [
-        getattr(pl.callbacks, callback["type"])(**callback["args"]) for callback in args
+        getattr(pl.callbacks, callback.type)(**callback.args) for callback in args
     ]
     return pl_callbacks
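For context, a minimal sketch (not part of this commit) of how the new callbacks block is consumed by the updated _configure_pl_callbacks: loading the YAML with OmegaConf yields DictConfig nodes, so attribute access (callback.type, callback.args) works where the old code used dict-style indexing. The standalone loading code and the repo-root-relative path are assumptions for illustration.

from omegaconf import OmegaConf
import pytorch_lightning as pl

# Assumed: run from the repo root so the experiment config added above resolves.
config = OmegaConf.load("training/experiments/image_transformer.yaml")

# Each entry of config.callbacks is a DictConfig node; attribute access works,
# which is what the updated _configure_pl_callbacks relies on instead of
# callback["type"] / callback["args"].
pl_callbacks = [
    getattr(pl.callbacks, callback.type)(**callback.args)
    for callback in config.callbacks
]
# -> [ModelCheckpoint(monitor="val_loss", mode="min"),
#     EarlyStopping(monitor="val_loss", mode="min", patience=10)]

Since OmegaConf nodes support both attribute and key access, the change is mainly about matching the new type hint; plain dicts passed to the function would no longer work.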