summaryrefslogtreecommitdiff
path: root/training/run.py
blob: 288a1ef8fbdce64115d757b3900b09579828a2d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Script to run experiments."""
from typing import Callable, List, Optional, Type

import hydra
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    seed_everything,
    Trainer,
)
from pytorch_lightning.loggers import LightningLoggerBase
from torch import nn
from torchinfo import summary
import utils


def _best_checkpoint_path(trainer: Trainer) -> str:
    """Returns the trainer's best checkpoint path, or "" if there is none.

    ``trainer.checkpoint_callback`` is ``None`` when no checkpoint callback is
    configured, and its ``best_model_path`` is an empty string (not ``None``)
    until a checkpoint has been saved, so both cases collapse to "".
    """
    checkpoint_callback = trainer.checkpoint_callback
    if checkpoint_callback is None:
        return ""
    return checkpoint_callback.best_model_path or ""


def run(config: DictConfig) -> Optional[float]:
    """Runs an experiment described by a Hydra config.

    Instantiates the datamodule, network, criterion, decoder, model, callbacks,
    loggers and trainer from ``config``, logs hyperparameters and saves the
    config. Then, depending on the ``debug``, ``tune``, ``train`` and ``test``
    flags, it runs a fast development fit, tunes, trains and/or tests the model.

    Args:
        config: Composed Hydra config. Must provide the ``datamodule``,
            ``network``, ``criterion``, ``decoder``, ``model``, ``optimizer``,
            ``lr_scheduler`` and ``trainer`` nodes and the ``debug``, ``tune``,
            ``train`` and ``test`` flags; ``seed`` and ``summary`` are optional.

    Returns:
        Currently always ``None``. The ``Optional[float]`` return type is kept
        for callers (NOTE(review): presumably reserved for an optimized metric
        in sweeps — confirm).
    """
    utils.configure_logging(config)
    log.info("Starting experiment...")

    # Compare against None so that an explicit seed of 0 is still applied.
    seed = config.get("seed")
    if seed is not None:
        seed_everything(seed, workers=True)

    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    log.info(f"Instantiating network <{config.network._target_}>")
    network: nn.Module = hydra.utils.instantiate(config.network)

    log.info(f"Instantiating criterion <{config.criterion._target_}>")
    loss_fn: nn.Module = hydra.utils.instantiate(config.criterion)

    log.info(f"Instantiating decoder <{config.decoder._target_}>")
    decoder: Callable = hydra.utils.instantiate(
        config.decoder,
        network=network,
        tokenizer=datamodule.tokenizer,
    )

    log.info(f"Instantiating model <{config.model._target_}>")
    # _recursive_=False: the optimizer / lr_scheduler configs are handed to the
    # model un-instantiated so the model can build them itself.
    model: LightningModule = hydra.utils.instantiate(
        config.model,
        network=network,
        tokenizer=datamodule.tokenizer,
        decoder=decoder,
        loss_fn=loss_fn,
        optimizer_config=config.optimizer,
        lr_scheduler_config=config.lr_scheduler,
        _recursive_=False,
    )

    # Load callbacks and loggers.
    callbacks: List[Callback] = utils.configure_callbacks(config)
    logger: List[LightningLoggerBase] = utils.configure_logger(config)

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    # _convert_="partial" turns the callback/logger lists into plain containers.
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    # Log hyperparameters.
    log.info("Logging hyperparameters")
    utils.log_hyperparameters(config=config, model=model, trainer=trainer)
    utils.save_config(config)

    if config.get("summary"):
        # Convert each configured entry to a plain Python list before handing
        # the input sizes to torchinfo.
        input_sizes = [list(size) for size in config.summary]
        summary(network, input_sizes, depth=1, device="cpu")

    if config.debug:
        log.info("Fast development run...")
        trainer.fit(model, datamodule=datamodule)
        return None

    if config.tune:
        log.info("Tuning hyperparameters...")
        trainer.tune(model, datamodule=datamodule)

    if config.train:
        log.info("Training network...")
        trainer.fit(model, datamodule=datamodule)

    if config.test:
        log.info("Testing network...")
        ckpt_path = _best_checkpoint_path(trainer)
        if not ckpt_path:
            # Skip testing but still fall through so the loggers are finished.
            log.error("No best checkpoint path for model found")
        else:
            trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)

    log.info(f"Best checkpoint path:\n{_best_checkpoint_path(trainer)}")
    utils.finish(logger)
    return None