1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
|
"""Script to run experiments."""
from typing import List, Optional, Type

import hydra
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from torch import nn

from utils import configure_logging
def configure_callbacks(
    config: DictConfig,
) -> List[Type[Callback]]:
    """Instantiate lightning callbacks declared in the config.

    Args:
        config: Full experiment config; only the optional ``callbacks``
            section is read. Each entry is instantiated via Hydra when it
            declares a ``_target_`` key.

    Returns:
        The instantiated callback objects; empty list when no ``callbacks``
        section is present.
    """
    callbacks = []
    if config.get("callbacks"):
        for callback_config in config.callbacks.values():
            # Bug fix: check the per-callback entry for `_target_`, not the
            # top-level config — the original guard inspected the wrong object.
            if callback_config.get("_target_"):
                log.info(f"Instantiating callback <{callback_config._target_}>")
                callbacks.append(hydra.utils.instantiate(callback_config))
    return callbacks
def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]:
    """Instantiate lightning loggers declared in the config.

    Args:
        config: Full experiment config; only the optional ``logger`` section
            is read. Each entry is instantiated via Hydra when it declares a
            ``_target_`` key.

    Returns:
        The instantiated logger objects; empty list when no ``logger``
        section is present.
    """
    logger = []
    if config.get("logger"):
        for logger_config in config.logger.values():
            # Bug fixes: check the per-logger entry (not the top-level config)
            # for `_target_`, and log "logger" rather than "callback".
            if logger_config.get("_target_"):
                log.info(f"Instantiating logger <{logger_config._target_}>")
                logger.append(hydra.utils.instantiate(logger_config))
    return logger
def run(config: DictConfig) -> Optional[float]:
    """Build the experiment from *config* and execute the requested stages.

    Instantiates the datamodule, network, model, and trainer via Hydra, then
    runs — depending on the ``debug``/``tune``/``train``/``test`` flags — a
    fast development fit, learning-rate/batch-size tuning, training, and
    testing, in that order.

    Args:
        config: Composed Hydra configuration with ``datamodule``,
            ``network``, ``model``, ``criterion``, ``optimizer``,
            ``lr_scheduler``, and ``trainer`` sections plus stage flags.

    Returns:
        ``None`` (the ``Optional[float]`` annotation leaves room for a
        metric to be returned in the future).
    """
    configure_logging(config.logging)
    log.info("Starting experiment...")

    # Seeding is optional; applied only when present in the config.
    if config.get("seed"):
        seed_everything(config.seed)

    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    dm: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    # The network receives the datamodule's config (e.g. data-derived kwargs).
    log.info(f"Instantiating network <{config.network._target_}>")
    net: nn.Module = hydra.utils.instantiate(config.network, **dm.config())

    # `_recursive_=False` leaves criterion/optimizer/scheduler configs
    # uninstantiated so the model can build them itself.
    log.info(f"Instantiating model <{config.model._target_}>")
    lit_model: LightningModule = hydra.utils.instantiate(
        config.model,
        network=net,
        criterion=config.criterion,
        optimizer=config.optimizer,
        lr_scheduler=config.lr_scheduler,
        _recursive_=False,
    )

    # Load callback and logger.
    callback_list = configure_callbacks(config)
    logger_list = configure_logger(config)

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callback_list, logger=logger_list, _convert_="partial"
    )

    # Log hyperparameters
    if config.debug:
        # Debug mode: one quick fit, then bail out before any other stage.
        log.info("Fast development run...")
        trainer.fit(lit_model, datamodule=dm)
        return None

    if config.tune:
        log.info("Tuning learning rate and batch size...")
        trainer.tune(lit_model, datamodule=dm)

    if config.train:
        log.info("Training network...")
        trainer.fit(lit_model, datamodule=dm)

    if config.test:
        log.info("Testing network...")
        trainer.test(lit_model, datamodule=dm)
    # Make sure everything closes properly
|