diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:24:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:24:20 +0200 |
commit | dedf8deb025ac9efdad5e9baf9165ef63d6829ff (patch) | |
tree | 56b10fcaef479d8abe9b0e6c05e07ad5e02b9ab0 /training | |
parent | 532286b516b17d279c321358bf03dddc8adc8029 (diff) |
Pre-commit fixes, optimizer loading fix
Diffstat (limited to 'training')
-rw-r--r-- | training/experiments/image_transformer.yaml | 12 | ||||
-rw-r--r-- | training/run_experiment.py | 4 |
2 files changed, 9 insertions, 7 deletions
diff --git a/training/experiments/image_transformer.yaml b/training/experiments/image_transformer.yaml
index 7f0bbb7..012a19b 100644
--- a/training/experiments/image_transformer.yaml
+++ b/training/experiments/image_transformer.yaml
@@ -1,6 +1,6 @@
 network:
   type: ImageTransformer
-  args:
+  args:
     input_shape: None
     output_shape: None
     encoder:
@@ -17,20 +17,20 @@ network:
 model:
   type: LitTransformerModel
   args:
-    optimizer:
+    optimizer:
       type: MADGRAD
       args:
         lr: 1.0e-2
         momentum: 0.9
         weight_decay: 0
         eps: 1.0e-6
-    lr_scheduler:
+    lr_scheduler:
       type: CosineAnnealingLR
-      args:
+      args:
         T_max: 512
     criterion:
       type: CrossEntropyLoss
-      args:
+      args:
         weight: None
         ignore_index: -100
         reduction: mean
@@ -40,7 +40,7 @@ model:
 
 data:
   type: IAMExtendedParagraphs
-  args:
+  args:
     batch_size: 16
     num_workers: 12
     train_fraction: 0.8
diff --git a/training/run_experiment.py b/training/run_experiment.py
index 8a29555..0a67bfa 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -50,7 +50,9 @@ def _import_class(module_and_class_name: str) -> type:
     return getattr(module, class_name)
 
 
-def _configure_pl_callbacks(args: List[Union[OmegaConf, NamedTuple]]) -> List[Type[pl.callbacks.Callback]]:
+def _configure_pl_callbacks(
+    args: List[Union[OmegaConf, NamedTuple]]
+) -> List[Type[pl.callbacks.Callback]]:
     """Configures PyTorch Lightning callbacks."""
     pl_callbacks = [
         getattr(pl.callbacks, callback.type)(**callback.args) for callback in args