summaryrefslogtreecommitdiff
path: root/training/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/utils.py')
-rw-r--r--training/utils.py23
1 files changed, 9 insertions, 14 deletions
diff --git a/training/utils.py b/training/utils.py
index ef74f61..d23396e 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -17,6 +17,10 @@ from tqdm import tqdm
import wandb
+def print_config(config: DictConfig) -> None:
+ print(OmegaConf.to_yaml(config))
+
+
@rank_zero_only
def configure_logging(config: DictConfig) -> None:
"""Configure the loguru logger for output to terminal and disk."""
@@ -30,7 +34,7 @@ def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
callbacks = []
if config.get("callbacks"):
for callback_config in config.callbacks.values():
- if config.get("_target_"):
+ if callback_config.get("_target_"):
log.info(f"Instantiating callback <{callback_config._target_}>")
callbacks.append(hydra.utils.instantiate(callback_config))
return callbacks
@@ -41,8 +45,8 @@ def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]:
logger = []
if config.get("logger"):
for logger_config in config.logger.values():
- if config.get("_target_"):
- log.info(f"Instantiating callback <{logger_config._target_}>")
+ if logger_config.get("_target_"):
+ log.info(f"Instantiating logger <{logger_config._target_}>")
logger.append(hydra.utils.instantiate(logger_config))
return logger
@@ -67,17 +71,8 @@ def extras(config: DictConfig) -> None:
# Debuggers do not like GPUs and multiprocessing.
if config.trainer.get("gpus"):
config.trainer.gpus = 0
- if config.datamodule.get("pin_memory"):
- config.datamodule.pin_memory = False
- if config.datamodule.get("num_workers"):
- config.datamodule.num_workers = 0
-
- # Force multi-gpu friendly config.
- accelerator = config.trainer.get("accelerator")
- if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]:
- log.info(
- f"Forcing ddp friendly configuration! <config.trainer.accelerator={accelerator}>"
- )
+ if config.trainer.get("precision"):
+ config.trainer.precision = 32
if config.datamodule.get("pin_memory"):
config.datamodule.pin_memory = False
if config.datamodule.get("num_workers"):