From 33190bc9c0c377edab280efe4b0bd0e53bb6cb00 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 6 Apr 2021 22:52:59 +0200 Subject: Refactor train script --- training/run_experiment.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/training/run_experiment.py b/training/run_experiment.py index ea9f512..289866e 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -52,7 +52,7 @@ def _import_class(module_and_class_name: str) -> type: def _configure_callbacks( args: List[Union[OmegaConf, NamedTuple]] ) -> List[Type[pl.callbacks.Callback]]: - """Configures PyTorch Lightning callbacks.""" + """Configures lightning callbacks.""" pl_callbacks = [ getattr(pl.callbacks, callback.type)(**callback.args) for callback in args ] @@ -62,7 +62,7 @@ def _configure_callbacks( def _configure_logger( network: Type[nn.Module], args: Dict, use_wandb: bool ) -> pl.loggers.WandbLogger: - """Configures PyTorch Lightning logger.""" + """Configures lightning logger.""" if use_wandb: pl_logger = pl.loggers.WandbLogger() pl_logger.watch(network) @@ -89,11 +89,23 @@ def _save_best_weights( wandb.save(best_model_path) +def _load_lit_model(lit_model_class: type, network: Type[nn.Module], config: OmegaConf) -> Type[pl.LightningModule]: + """Load lightning model.""" + if config.load_checkpoint is not None: + logger.info( + f"Loading network weights from checkpoint: {config.load_checkpoint}" + ) + return lit_model_class.load_from_checkpoint( + config.load_checkpoint, network=network, **config.model.args + ) + return lit_model_class(network=network, **config.model.args) + + def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None: """Runs experiment.""" logger.info("Starting experiment...") - # Seed everything in the experiment + # Seed everything in the experiment. logger.info(f"Seeding everthing with seed={SEED}") pl.utilities.seed.seed_everything(SEED) @@ -101,7 +113,7 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None logger.info(f"Loading config from: {path}") config = OmegaConf.load(path) - # Load classes + # Load classes. data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") network_class = _import_class(f"text_recognizer.networks.{config.network.type}") lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}") @@ -110,23 +122,15 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None data_module = data_module_class(**config.data.args) network = network_class(**data_module.config(), **config.network.args) - # Load callback and logger + # Load callback and logger. callbacks = _configure_callbacks(config.callbacks) pl_logger = _configure_logger(network, config, use_wandb) - # Checkpoint - if config.load_checkpoint is not None: - logger.info( - f"Loading network weights from checkpoint: {config.load_checkpoint}" - ) - lit_model = lit_model_class.load_from_checkpoint( - config.load_checkpoint, network=network, **config.model.args - ) - else: - lit_model = lit_model_class(network=network, **config.model.args) + # Load ligtning model. + lit_model = _load_lit_model(lit_model_class, network, config) trainer = pl.Trainer( - **config.trainer, + **config.trainer.args, callbacks=callbacks, logger=pl_logger, weigths_save_path="training/logs", -- cgit v1.2.3-70-g09d2