summaryrefslogtreecommitdiff
path: root/training/run_experiment.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r--training/run_experiment.py54
1 files changed, 36 insertions, 18 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py
index def1e77..b3c9552 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -3,6 +3,9 @@ from datetime import datetime
import importlib
from pathlib import Path
from typing import List, Optional, Type
+import warnings
+
+warnings.filterwarnings("ignore")
import hydra
from loguru import logger
@@ -29,7 +32,7 @@ def _create_experiment_dir(config: DictConfig) -> Path:
def _save_config(config: DictConfig, log_dir: Path) -> None:
"""Saves config to log directory."""
- with (log_dir / "config.yaml").open("r") as f:
+ with (log_dir / "config.yaml").open("w") as f:
OmegaConf.save(config=config, f=f)
@@ -52,12 +55,11 @@ def _import_class(module_and_class_name: str) -> type:
return getattr(module, class_name)
-def _configure_callbacks(
- callbacks: List[DictConfig],
-) -> List[Type[pl.callbacks.Callback]]:
+def _configure_callbacks(callbacks: DictConfig,) -> List[Type[pl.callbacks.Callback]]:
"""Configures lightning callbacks."""
pl_callbacks = [
- getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks
+ getattr(pl.callbacks, callback.type)(**callback.args)
+ for callback in callbacks.values()
]
return pl_callbacks
@@ -77,12 +79,12 @@ def _configure_logger(
def _save_best_weights(
- callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
+ pl_callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
) -> None:
"""Saves the best model."""
model_checkpoint_callback = next(
callback
- for callback in callbacks
+ for callback in pl_callbacks
if isinstance(callback, pl.callbacks.ModelCheckpoint)
)
best_model_path = model_checkpoint_callback.best_model_path
@@ -97,20 +99,31 @@ def _load_lit_model(
lit_model_class: type, network: Type[nn.Module], config: DictConfig
) -> Type[pl.LightningModule]:
"""Load lightning model."""
- if config.trainer.load_checkpoint is not None:
+ if config.load_checkpoint is not None:
logger.info(
- f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}"
+ f"Loading network weights from checkpoint: {config.load_checkpoint}"
)
return lit_model_class.load_from_checkpoint(
- config.trainer.load_checkpoint, network=network, **config.model.args
+ config.load_checkpoint,
+ network=network,
+ optimizer=config.optimizer,
+ criterion=config.criterion,
+ lr_scheduler=config.lr_scheduler,
+ **config.model.args,
)
- return lit_model_class(network=network, **config.model.args)
+ return lit_model_class(
+ network=network,
+ optimizer=config.optimizer,
+ criterion=config.criterion,
+ lr_scheduler=config.lr_scheduler,
+ **config.model.args,
+ )
def run(config: DictConfig) -> None:
"""Runs experiment."""
log_dir = _create_experiment_dir(config)
- _configure_logging(log_dir, level=config.trainer.logging)
+ _configure_logging(log_dir, level=config.logging)
logger.info("Starting experiment...")
pl.utilities.seed.seed_everything(config.trainer.seed)
@@ -125,7 +138,7 @@ def run(config: DictConfig) -> None:
network = network_class(**data_module.config(), **config.network.args)
# Load callback and logger.
- callbacks = _configure_callbacks(config.callbacks)
+ pl_callbacks = _configure_callbacks(config.callbacks)
pl_logger = _configure_logger(network, config, log_dir)
# Load ligtning model.
@@ -136,12 +149,17 @@ def run(config: DictConfig) -> None:
trainer = pl.Trainer(
**config.trainer.args,
- callbacks=callbacks,
+ callbacks=pl_callbacks,
logger=pl_logger,
weights_save_path=str(log_dir),
)
- if config.trainer.tune and not config.trainer.args.fast_dev_run:
+ if config.trainer.args.fast_dev_run:
+ logger.info("Fast development run...")
+ trainer.fit(lit_model, datamodule=data_module)
+ return None
+
+ if config.trainer.tune:
logger.info("Tuning learning rate and batch size...")
trainer.tune(lit_model, datamodule=data_module)
@@ -149,17 +167,17 @@ def run(config: DictConfig) -> None:
logger.info("Training network...")
trainer.fit(lit_model, datamodule=data_module)
- if config.trainer.test and not config.trainer.args.fast_dev_run:
+ if config.trainer.test:
logger.info("Testing network...")
trainer.test(lit_model, datamodule=data_module)
- if not config.trainer.args.fast_dev_run:
- _save_best_weights(callbacks, config.trainer.wandb)
+ _save_best_weights(pl_callbacks, config.trainer.wandb)
@hydra.main(config_path="conf", config_name="config")
def main(config: DictConfig) -> None:
"""Loads config with hydra."""
+ print(OmegaConf.to_yaml(config))
run(config)