diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 18:18:48 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 18:18:48 +0200 |
commit | bd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch) | |
tree | e55cb3744904f7c2a0348b100c7e92a65e538a16 /training/utils.py | |
parent | 75801019981492eedf9280cb352eea3d8e99b65f (diff) |
Training working, multiple bug fixes
Diffstat (limited to 'training/utils.py')
-rw-r--r-- | training/utils.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/training/utils.py b/training/utils.py index ef74f61..d23396e 100644 --- a/training/utils.py +++ b/training/utils.py @@ -17,6 +17,10 @@ from tqdm import tqdm import wandb +def print_config(config: DictConfig) -> None: + print(OmegaConf.to_yaml(config)) + + @rank_zero_only def configure_logging(config: DictConfig) -> None: """Configure the loguru logger for output to terminal and disk.""" @@ -30,7 +34,7 @@ def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]: callbacks = [] if config.get("callbacks"): for callback_config in config.callbacks.values(): - if config.get("_target_"): + if callback_config.get("_target_"): log.info(f"Instantiating callback <{callback_config._target_}>") callbacks.append(hydra.utils.instantiate(callback_config)) return callbacks @@ -41,8 +45,8 @@ def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]: logger = [] if config.get("logger"): for logger_config in config.logger.values(): - if config.get("_target_"): - log.info(f"Instantiating callback <{logger_config._target_}>") + if logger_config.get("_target_"): + log.info(f"Instantiating logger <{logger_config._target_}>") logger.append(hydra.utils.instantiate(logger_config)) return logger @@ -67,17 +71,8 @@ def extras(config: DictConfig) -> None: # Debuggers do not like GPUs and multiprocessing. if config.trainer.get("gpus"): config.trainer.gpus = 0 - if config.datamodule.get("pin_memory"): - config.datamodule.pin_memory = False - if config.datamodule.get("num_workers"): - config.datamodule.num_workers = 0 - - # Force multi-gpu friendly config. - accelerator = config.trainer.get("accelerator") - if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]: - log.info( - f"Forcing ddp friendly configuration! <config.trainer.accelerator={accelerator}>" - ) + if config.trainer.get("precision"): + config.trainer.precision = 32 if config.datamodule.get("pin_memory"): config.datamodule.pin_memory = False if config.datamodule.get("num_workers"): |