summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/conf/callbacks/checkpoint.yaml5
-rw-r--r--training/conf/callbacks/early_stopping.yaml5
-rw-r--r--training/conf/callbacks/learning_rate_monitor.yaml5
-rw-r--r--training/conf/callbacks/swa.yaml5
-rw-r--r--training/conf/config.yaml6
-rw-r--r--training/conf/criterion/mse.yaml3
-rw-r--r--training/conf/dataset/iam_extended_paragraphs.yaml9
-rw-r--r--training/conf/lr_scheduler/one_cycle.yaml8
-rw-r--r--training/conf/model/lit_vqvae.yaml23
-rw-r--r--training/conf/network/vqvae.yaml23
-rw-r--r--training/conf/optimizer/madgrad.yaml6
-rw-r--r--training/conf/trainer/default.yaml23
-rw-r--r--training/run_experiment.py54
13 files changed, 97 insertions, 78 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index afc536f..f3beb1b 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -1,5 +1,6 @@
-type: ModelCheckpoint
-args:
+checkpoint:
+ type: ModelCheckpoint
+ args:
monitor: val_loss
mode: min
save_last: true
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
index caab824..ec671fd 100644
--- a/training/conf/callbacks/early_stopping.yaml
+++ b/training/conf/callbacks/early_stopping.yaml
@@ -1,5 +1,6 @@
-type: EarlyStopping
-args:
+early_stopping:
+ type: EarlyStopping
+ args:
monitor: val_loss
mode: min
patience: 10
diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml
index 003ab7a..11a5ecf 100644
--- a/training/conf/callbacks/learning_rate_monitor.yaml
+++ b/training/conf/callbacks/learning_rate_monitor.yaml
@@ -1,3 +1,4 @@
-type: LearningRateMonitor
-args:
+learning_rate_monitor:
+ type: LearningRateMonitor
+ args:
logging_interval: step
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
index 279ca69..92d9e6b 100644
--- a/training/conf/callbacks/swa.yaml
+++ b/training/conf/callbacks/swa.yaml
@@ -1,5 +1,6 @@
-type: StochasticWeightAveraging
-args:
+stochastic_weight_averaging:
+ type: StochasticWeightAveraging
+ args:
swa_epoch_start: 0.8
swa_lrs: 0.05
annealing_epochs: 10
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index c413a1a..b43e375 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,8 +1,14 @@
defaults:
- network: vqvae
+ - criterion: mse
+ - optimizer: madgrad
+ - lr_scheduler: one_cycle
- model: lit_vqvae
- dataset: iam_extended_paragraphs
- trainer: default
- callbacks:
- checkpoint
- learning_rate_monitor
+
+load_checkpoint: null
+logging: INFO
diff --git a/training/conf/criterion/mse.yaml b/training/conf/criterion/mse.yaml
new file mode 100644
index 0000000..4d89cbc
--- /dev/null
+++ b/training/conf/criterion/mse.yaml
@@ -0,0 +1,3 @@
+type: MSELoss
+args:
+ reduction: mean
diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml
index 6bd7fc9..6439a15 100644
--- a/training/conf/dataset/iam_extended_paragraphs.yaml
+++ b/training/conf/dataset/iam_extended_paragraphs.yaml
@@ -1,7 +1,6 @@
-# @package _group_
type: IAMExtendedParagraphs
args:
- batch_size: 32
- num_workers: 12
- train_fraction: 0.8
- augment: true
+ batch_size: 32
+ num_workers: 12
+ train_fraction: 0.8
+ augment: true
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
new file mode 100644
index 0000000..60a6f27
--- /dev/null
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -0,0 +1,8 @@
+type: OneCycleLR
+args:
+ interval: step
+ max_lr: 1.0e-3
+ three_phase: true
+ epochs: 64
+ steps_per_epoch: 633 # num_samples / batch_size
+monitor: val_loss
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 90780b7..7136dbd 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,24 +1,3 @@
-# @package _group_
type: LitVQVAEModel
args:
- optimizer:
- type: MADGRAD
- args:
- lr: 1.0e-3
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
- lr_scheduler:
- type: OneCycleLR
- args:
- interval: step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 64
- steps_per_epoch: 633 # num_samples / batch_size
- criterion:
- type: MSELoss
- args:
- reduction: mean
- monitor: val_loss
- mapping: sentence_piece
+ mapping: sentence_piece
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 288d2aa..22eebf8 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -1,14 +1,13 @@
-# @package _group_
type: VQVAE
args:
- in_channels: 1
- channels: [64, 96]
- kernel_sizes: [4, 4]
- strides: [2, 2]
- num_residual_layers: 2
- embedding_dim: 64
- num_embeddings: 256
- upsampling: null
- beta: 0.25
- activation: leaky_relu
- dropout_rate: 0.2
+ in_channels: 1
+ channels: [64, 96]
+ kernel_sizes: [4, 4]
+ strides: [2, 2]
+ num_residual_layers: 2
+ embedding_dim: 64
+ num_embeddings: 256
+ upsampling: null
+ beta: 0.25
+ activation: leaky_relu
+ dropout_rate: 0.2
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
new file mode 100644
index 0000000..2f2cff9
--- /dev/null
+++ b/training/conf/optimizer/madgrad.yaml
@@ -0,0 +1,6 @@
+type: MADGRAD
+args:
+ lr: 1.0e-3
+ momentum: 0.9
+ weight_decay: 0
+ eps: 1.0e-6
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index 3a88c6a..5797741 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -1,19 +1,16 @@
-# @package _group_
seed: 4711
-load_checkpoint: null
wandb: false
tune: false
train: true
test: true
-logging: INFO
args:
- stochastic_weight_avg: false
- auto_scale_batch_size: binsearch
- auto_lr_find: false
- gradient_clip_val: 0
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: 64
- terminate_on_nan: true
- weights_summary: top
+ stochastic_weight_avg: false
+ auto_scale_batch_size: binsearch
+ auto_lr_find: false
+ gradient_clip_val: 0
+ fast_dev_run: false
+ gpus: 1
+ precision: 16
+ max_epochs: 64
+ terminate_on_nan: true
+ weights_summary: top
diff --git a/training/run_experiment.py b/training/run_experiment.py
index def1e77..b3c9552 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -3,6 +3,9 @@ from datetime import datetime
import importlib
from pathlib import Path
from typing import List, Optional, Type
+import warnings
+
+warnings.filterwarnings("ignore")
import hydra
from loguru import logger
@@ -29,7 +32,7 @@ def _create_experiment_dir(config: DictConfig) -> Path:
def _save_config(config: DictConfig, log_dir: Path) -> None:
"""Saves config to log directory."""
- with (log_dir / "config.yaml").open("r") as f:
+ with (log_dir / "config.yaml").open("w") as f:
OmegaConf.save(config=config, f=f)
@@ -52,12 +55,11 @@ def _import_class(module_and_class_name: str) -> type:
return getattr(module, class_name)
-def _configure_callbacks(
- callbacks: List[DictConfig],
-) -> List[Type[pl.callbacks.Callback]]:
+def _configure_callbacks(callbacks: DictConfig,) -> List[Type[pl.callbacks.Callback]]:
"""Configures lightning callbacks."""
pl_callbacks = [
- getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks
+ getattr(pl.callbacks, callback.type)(**callback.args)
+ for callback in callbacks.values()
]
return pl_callbacks
@@ -77,12 +79,12 @@ def _configure_logger(
def _save_best_weights(
- callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
+ pl_callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
) -> None:
"""Saves the best model."""
model_checkpoint_callback = next(
callback
- for callback in callbacks
+ for callback in pl_callbacks
if isinstance(callback, pl.callbacks.ModelCheckpoint)
)
best_model_path = model_checkpoint_callback.best_model_path
@@ -97,20 +99,31 @@ def _load_lit_model(
lit_model_class: type, network: Type[nn.Module], config: DictConfig
) -> Type[pl.LightningModule]:
"""Load lightning model."""
- if config.trainer.load_checkpoint is not None:
+ if config.load_checkpoint is not None:
logger.info(
- f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}"
+ f"Loading network weights from checkpoint: {config.load_checkpoint}"
)
return lit_model_class.load_from_checkpoint(
- config.trainer.load_checkpoint, network=network, **config.model.args
+ config.load_checkpoint,
+ network=network,
+ optimizer=config.optimizer,
+ criterion=config.criterion,
+ lr_scheduler=config.lr_scheduler,
+ **config.model.args,
)
- return lit_model_class(network=network, **config.model.args)
+ return lit_model_class(
+ network=network,
+ optimizer=config.optimizer,
+ criterion=config.criterion,
+ lr_scheduler=config.lr_scheduler,
+ **config.model.args,
+ )
def run(config: DictConfig) -> None:
"""Runs experiment."""
log_dir = _create_experiment_dir(config)
- _configure_logging(log_dir, level=config.trainer.logging)
+ _configure_logging(log_dir, level=config.logging)
logger.info("Starting experiment...")
pl.utilities.seed.seed_everything(config.trainer.seed)
@@ -125,7 +138,7 @@ def run(config: DictConfig) -> None:
network = network_class(**data_module.config(), **config.network.args)
# Load callback and logger.
- callbacks = _configure_callbacks(config.callbacks)
+ pl_callbacks = _configure_callbacks(config.callbacks)
pl_logger = _configure_logger(network, config, log_dir)
# Load ligtning model.
@@ -136,12 +149,17 @@ def run(config: DictConfig) -> None:
trainer = pl.Trainer(
**config.trainer.args,
- callbacks=callbacks,
+ callbacks=pl_callbacks,
logger=pl_logger,
weights_save_path=str(log_dir),
)
- if config.trainer.tune and not config.trainer.args.fast_dev_run:
+ if config.trainer.args.fast_dev_run:
+ logger.info("Fast development run...")
+ trainer.fit(lit_model, datamodule=data_module)
+ return None
+
+ if config.trainer.tune:
logger.info("Tuning learning rate and batch size...")
trainer.tune(lit_model, datamodule=data_module)
@@ -149,17 +167,17 @@ def run(config: DictConfig) -> None:
logger.info("Training network...")
trainer.fit(lit_model, datamodule=data_module)
- if config.trainer.test and not config.trainer.args.fast_dev_run:
+ if config.trainer.test:
logger.info("Testing network...")
trainer.test(lit_model, datamodule=data_module)
- if not config.trainer.args.fast_dev_run:
- _save_best_weights(callbacks, config.trainer.wandb)
+ _save_best_weights(pl_callbacks, config.trainer.wandb)
@hydra.main(config_path="conf", config_name="config")
def main(config: DictConfig) -> None:
"""Loads config with hydra."""
+ print(OmegaConf.to_yaml(config))
run(config)