author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-06 22:52:59 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-06 22:52:59 +0200
commit     33190bc9c0c377edab280efe4b0bd0e53bb6cb00 (patch)
tree       d10b36e4a15e6c9c9df9fd549f4870cb6bcd37bd
parent     31d58f2108165802d26eb1c1bdb9e5f052b4dd26 (diff)
Refactor train script
-rw-r--r--  training/run_experiment.py | 36
1 file changed, 20 insertions(+), 16 deletions(-)
diff --git a/training/run_experiment.py b/training/run_experiment.py
index ea9f512..289866e 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -52,7 +52,7 @@ def _import_class(module_and_class_name: str) -> type:
def _configure_callbacks(
args: List[Union[OmegaConf, NamedTuple]]
) -> List[pl.callbacks.Callback]:
- """Configures PyTorch Lightning callbacks."""
+    """Configures Lightning callbacks."""
pl_callbacks = [
getattr(pl.callbacks, callback.type)(**callback.args) for callback in args
]
@@ -62,7 +62,7 @@ def _configure_callbacks(
def _configure_logger(
    network: nn.Module, args: Dict, use_wandb: bool
) -> pl.loggers.WandbLogger:
- """Configures PyTorch Lightning logger."""
+    """Configures Lightning logger."""
if use_wandb:
pl_logger = pl.loggers.WandbLogger()
pl_logger.watch(network)
@@ -89,11 +89,23 @@ def _save_best_weights(
wandb.save(best_model_path)
+def _load_lit_model(lit_model_class: type, network: nn.Module, config: OmegaConf) -> pl.LightningModule:
+    """Loads the Lightning model, from a checkpoint if one is configured."""
+ if config.load_checkpoint is not None:
+ logger.info(
+ f"Loading network weights from checkpoint: {config.load_checkpoint}"
+ )
+ return lit_model_class.load_from_checkpoint(
+ config.load_checkpoint, network=network, **config.model.args
+ )
+ return lit_model_class(network=network, **config.model.args)
+
+
def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None:
"""Runs experiment."""
logger.info("Starting experiment...")
- # Seed everything in the experiment
+ # Seed everything in the experiment.
logger.info(f"Seeding everthing with seed={SEED}")
pl.utilities.seed.seed_everything(SEED)
@@ -101,7 +113,7 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None
logger.info(f"Loading config from: {path}")
config = OmegaConf.load(path)
- # Load classes
+ # Load classes.
data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")
@@ -110,23 +122,15 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None
data_module = data_module_class(**config.data.args)
network = network_class(**data_module.config(), **config.network.args)
- # Load callback and logger
+ # Load callback and logger.
callbacks = _configure_callbacks(config.callbacks)
pl_logger = _configure_logger(network, config, use_wandb)
- # Checkpoint
- if config.load_checkpoint is not None:
- logger.info(
- f"Loading network weights from checkpoint: {config.load_checkpoint}"
- )
- lit_model = lit_model_class.load_from_checkpoint(
- config.load_checkpoint, network=network, **config.model.args
- )
- else:
- lit_model = lit_model_class(network=network, **config.model.args)
+    # Load Lightning model.
+ lit_model = _load_lit_model(lit_model_class, network, config)
trainer = pl.Trainer(
- **config.trainer,
+ **config.trainer.args,
callbacks=callbacks,
logger=pl_logger,
        weights_save_path="training/logs",
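
Note on the trainer change: switching from **config.trainer to **config.trainer.args means the trainer section of the experiment config is now expected to nest its keyword arguments under an args key, mirroring the data, network, model, and callbacks sections. Below is a minimal sketch of the layout the refactored script assumes; all type names and values are illustrative placeholders, not the repo's actual configs.

    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            # Path to a Lightning checkpoint, or None to build the model from scratch.
            "load_checkpoint": None,
            "data": {"type": "SomeDataModule", "args": {}},
            "network": {"type": "SomeNetwork", "args": {}},
            "model": {"type": "SomeLitModel", "args": {"lr": 1e-3}},
            # Each entry names a class in pl.callbacks, instantiated with its args.
            "callbacks": [{"type": "ModelCheckpoint", "args": {"monitor": "val_loss"}}],
            # Trainer kwargs now live one level down, hence **config.trainer.args.
            "trainer": {"args": {"max_epochs": 10, "gpus": 0}},
        }
    )

Factoring the checkpoint branch out into _load_lit_model keeps run() a straight-line sequence of load steps and gives the checkpoint-or-fresh decision a single, testable home.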