| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 13:51:15 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 13:51:15 +0200 | 
| commit | 1d0977585f01c42e9f6280559a1a98037907a62e (patch) | |
| tree | 7e86dd71b163f3138ed2658cb52c44e805f21539 /training | |
| parent | 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (diff) | |
Implemented training script with hydra
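The entry point now follows Hydra's decorator pattern instead of click options. As a minimal standalone sketch of that pattern (illustrative only, not the committed code; it assumes nothing beyond the `conf/config.yaml` added by this commit), the decorated function receives the composed config, and config groups or individual values can be overridden on the command line:

```python
# Minimal sketch of the Hydra entrypoint pattern used by this commit (illustrative only).
# Assumes the conf/ layout added below: conf/config.yaml with a defaults list.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # Print the fully composed config (network, model, dataset, trainer, callbacks).
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    # Groups and values can be swapped from the CLI, e.g.:
    #   python run_experiment.py callbacks=swa trainer.args.max_epochs=100
    main()
```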
Diffstat (limited to 'training')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | training/.gitignore | 1 |
| -rw-r--r-- | training/conf/callbacks/default.yaml | 14 |
| -rw-r--r-- | training/conf/callbacks/swa.yaml | 16 |
| -rw-r--r-- | training/conf/cnn_transformer.yaml (renamed from training/configs/cnn_transformer.yaml) | 0 |
| -rw-r--r-- | training/conf/config.yaml | 6 |
| -rw-r--r-- | training/conf/dataset/iam_extended_paragraphs.yaml | 7 |
| -rw-r--r-- | training/conf/model/lit_vqvae.yaml | 24 |
| -rw-r--r-- | training/conf/network/vqvae.yaml | 14 |
| -rw-r--r-- | training/conf/trainer/default.yaml | 18 |
| -rw-r--r-- | training/configs/vqvae.yaml | 89 |
| -rw-r--r-- | training/run_experiment.py | 136 |
11 files changed, 138 insertions, 187 deletions
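The monolithic `training/configs/vqvae.yaml` is split into per-group files under `training/conf/`, and the callbacks keep their `type`/`args` schema. The diff below does not show the body of `_configure_callbacks`, so the following is only a hedged sketch (with a hypothetical function name) of how such a list could be turned into PyTorch Lightning callback objects:

```python
# Hedged sketch: instantiating callbacks from the type/args schema used by
# conf/callbacks/*.yaml. The actual _configure_callbacks body is not part of this diff.
from typing import List

import pytorch_lightning as pl
from omegaconf import DictConfig


def build_callbacks(callback_configs: List[DictConfig]) -> List[pl.callbacks.Callback]:
    callbacks = []
    for cfg in callback_configs:
        # Resolve the class by name on pytorch_lightning.callbacks, e.g. ModelCheckpoint.
        callback_class = getattr(pl.callbacks, cfg.type)
        callbacks.append(callback_class(**cfg.args))
    return callbacks
```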
diff --git a/training/.gitignore b/training/.gitignore
index 333c1e9..7d268ea 100644
--- a/training/.gitignore
+++ b/training/.gitignore
@@ -1 +1,2 @@
 logs/
+outputs/
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
new file mode 100644
index 0000000..74dc30c
--- /dev/null
+++ b/training/conf/callbacks/default.yaml
@@ -0,0 +1,14 @@
+# @package _group_
+- type: ModelCheckpoint
+  args:
+      monitor: val_loss
+      mode: min
+      save_last: true
+- type: LearningRateMonitor
+  args:
+      logging_interval: step
+# - type: EarlyStopping
+#   args:
+#       monitor: val_loss
+#       mode: min
+#       patience: 10
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
new file mode 100644
index 0000000..144ad6e
--- /dev/null
+++ b/training/conf/callbacks/swa.yaml
@@ -0,0 +1,16 @@
+# @package _group_
+- type: ModelCheckpoint
+  args:
+      monitor: val_loss
+      mode: min
+      save_last: true
+- type: StochasticWeightAveraging
+  args:
+      swa_epoch_start: 0.8
+      swa_lrs: 0.05
+      annealing_epochs: 10
+      annealing_strategy: cos
+      device: null
+- type: LearningRateMonitor
+  args:
+      logging_interval: step
diff --git a/training/configs/cnn_transformer.yaml b/training/conf/cnn_transformer.yaml
index a4f16df..a4f16df 100644
--- a/training/configs/cnn_transformer.yaml
+++ b/training/conf/cnn_transformer.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
new file mode 100644
index 0000000..11adeb7
--- /dev/null
+++ b/training/conf/config.yaml
@@ -0,0 +1,6 @@
+defaults:
+    - network: vqvae
+    - model: lit_vqvae
+    - dataset: iam_extended_paragraphs
+    - trainer: default
+    - callbacks: default
diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml
new file mode 100644
index 0000000..6bd7fc9
--- /dev/null
+++ b/training/conf/dataset/iam_extended_paragraphs.yaml
@@ -0,0 +1,7 @@
+# @package _group_
+type: IAMExtendedParagraphs
+args:
+    batch_size: 32
+    num_workers: 12
+    train_fraction: 0.8
+    augment: true
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
new file mode 100644
index 0000000..90780b7
--- /dev/null
+++ b/training/conf/model/lit_vqvae.yaml
@@ -0,0 +1,24 @@
+# @package _group_
+type: LitVQVAEModel
+args:
+    optimizer:
+        type: MADGRAD
+        args:
+            lr: 1.0e-3
+            momentum: 0.9
+            weight_decay: 0
+            eps: 1.0e-6
+    lr_scheduler:
+        type: OneCycleLR
+        args:
+            interval: step
+            max_lr: 1.0e-3
+            three_phase: true
+            epochs: 64
+            steps_per_epoch: 633 # num_samples / batch_size
+    criterion:
+        type: MSELoss
+        args:
+            reduction: mean
+    monitor: val_loss
+    mapping: sentence_piece
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
new file mode 100644
index 0000000..8c30bbd
--- /dev/null
+++ b/training/conf/network/vqvae.yaml
@@ -0,0 +1,14 @@
+# @package _group_
+type: VQVAE
+args:
+    in_channels: 1
+    channels: [32, 64, 64]
+    kernel_sizes: [4, 4, 4]
+    strides: [2, 2, 2]
+    num_residual_layers: 2
+    embedding_dim: 64
+    num_embeddings: 256
+    upsampling: null
+    beta: 0.25
+    activation: leaky_relu
+    dropout_rate: 0.2
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
new file mode 100644
index 0000000..82afd93
--- /dev/null
+++ b/training/conf/trainer/default.yaml
@@ -0,0 +1,18 @@
+# @package _group_
+seed: 4711
+load_checkpoint: null
+wandb: false
+tune: false
+train: true
+test: true
+logging: INFO
+args:
+    stochastic_weight_avg: false
+    auto_scale_batch_size: binsearch
+    gradient_clip_val: 0
+    fast_dev_run: false
+    gpus: 1
+    precision: 16
+    max_epochs: 64
+    terminate_on_nan: true
+    weights_summary: top
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
deleted file mode 100644
index 13d7c97..0000000
--- a/training/configs/vqvae.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-seed: 4711
-
-network:
-        desc: Configuration of the PyTorch neural network.
-        type: VQVAE
-        args:
-            in_channels: 1
-            channels: [32, 64, 64, 96, 96]
-            kernel_sizes: [4, 4, 4, 4, 4]
-            strides: [2, 2, 2, 2, 2]
-            num_residual_layers: 2
-            embedding_dim: 512
-            num_embeddings: 1024
-            upsampling: null
-            beta: 0.25
-            activation: leaky_relu
-            dropout_rate: 0.2
-
-model:
-        desc: Configuration of the PyTorch Lightning model.
-        type: LitVQVAEModel
-        args:
-                optimizer:
-                        type: MADGRAD
-                        args:
-                                lr: 1.0e-3
-                                momentum: 0.9
-                                weight_decay: 0
-                                eps: 1.0e-6
-                lr_scheduler:
-                        type: OneCycleLR
-                        args:
-                                interval: &interval step
-                                max_lr: 1.0e-3
-                                three_phase: true
-                                epochs: 64
-                                steps_per_epoch: 633 # num_samples / batch_size
-                criterion:
-                        type: MSELoss
-                        args:
-                                reduction: mean
-                monitor: val_loss
-                mapping: sentence_piece
-
-data:
-        desc: Configuration of the training/test data.
-        type: IAMExtendedParagraphs
-        args:
-                batch_size: 32
-                num_workers: 12
-                train_fraction: 0.8
-                augment: true
-
-callbacks:
-        - type: ModelCheckpoint
-          args:
-                  monitor: val_loss
-                  mode: min
-                  save_last: true
-        - type: StochasticWeightAveraging
-          args:
-                  swa_epoch_start: 0.8
-                  swa_lrs: 0.05
-                  annealing_epochs: 10
-                  annealing_strategy: cos
-                  device: null
-        - type: LearningRateMonitor
-          args:
-                  logging_interval: *interval
-        # - type: EarlyStopping
-        #   args:
-        #           monitor: val_loss
-        #           mode: min
-        #           patience: 10
-
-trainer:
-        desc: Configuration of the PyTorch Lightning Trainer.
-        args:
-                stochastic_weight_avg: true
-                auto_scale_batch_size: binsearch
-                gradient_clip_val: 0
-                fast_dev_run: false
-                gpus: 1
-                precision: 16
-                max_epochs: 64
-                terminate_on_nan: true
-                weights_summary: top
-
-load_checkpoint: null
diff --git a/training/run_experiment.py b/training/run_experiment.py
index bdefbf0..2b3ecab 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -4,17 +4,15 @@ import importlib
 from pathlib import Path
 from typing import Dict, List, Optional, Type

-import click
+import hydra
 from loguru import logger
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
 import pytorch_lightning as pl
 from torch import nn
 from tqdm import tqdm
 import wandb

-SEED = 4711
-CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs"
 LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs"
@@ -29,21 +27,10 @@ def _create_experiment_dir(config: DictConfig) -> Path:
     return log_dir


-def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
+def _configure_logging(log_dir: Optional[Path], level: str) -> None:
     """Configure the loguru logger for output to terminal and disk."""
-
-    def _get_level(verbose: int) -> str:
-        """Sets the logger level."""
-        levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"}
-        verbose = min(verbose, 2)
-        return levels[verbose]
-
     # Remove default logger to get tqdm to work properly.
     logger.remove()
-
-    # Fetch verbosity level.
-    level = _get_level(verbose)
-
     logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
     if log_dir is not None:
         logger.add(
@@ -52,14 +39,6 @@ def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
         )


-def _load_config(file_path: Path) -> DictConfig:
-    """Return experiment config."""
-    logger.info(f"Loading config from: {file_path}")
-    if not file_path.exists():
-        raise FileNotFoundError(f"Experiment config not found at: {file_path}")
-    return OmegaConf.load(file_path)
-
-
 def _import_class(module_and_class_name: str) -> type:
     """Import class from module."""
     module_name, class_name = module_and_class_name.rsplit(".", 1)
@@ -78,14 +57,16 @@ def _configure_callbacks(


 def _configure_logger(
-    network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool
+    network: Type[nn.Module], config: DictConfig, log_dir: Path
 ) -> Type[pl.loggers.LightningLoggerBase]:
     """Configures lightning logger."""
-    if use_wandb:
+    if config.trainer.wandb:
+        logger.info("Logging model with W&B")
         pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir))
         pl_logger.watch(network)
-        pl_logger.log_hyperparams(vars(args))
+        pl_logger.log_hyperparams(vars(config))
         return pl_logger
+    logger.info("Logging model with Tensorboard")
     return pl.loggers.TensorBoardLogger(save_dir=str(log_dir))
@@ -110,50 +91,36 @@ def _load_lit_model(
     lit_model_class: type, network: Type[nn.Module], config: DictConfig
 ) -> Type[pl.LightningModule]:
     """Load lightning model."""
-    if config.load_checkpoint is not None:
+    if config.trainer.load_checkpoint is not None:
         logger.info(
-            f"Loading network weights from checkpoint: {config.load_checkpoint}"
+            f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}"
         )
         return lit_model_class.load_from_checkpoint(
-            config.load_checkpoint, network=network, **config.model.args
+            config.trainer.load_checkpoint, network=network, **config.model.args
         )
     return lit_model_class(network=network, **config.model.args)


-def run(
-    filename: str,
-    fast_dev_run: bool,
-    train: bool,
-    test: bool,
-    tune: bool,
-    use_wandb: bool,
-    verbose: int = 0,
-) -> None:
+def run(config: DictConfig) -> None:
     """Runs experiment."""
-    # Load config.
-    file_path = CONFIGS_DIRNAME / filename
-    config = _load_config(file_path)
-
     log_dir = _create_experiment_dir(config)
-    _configure_logging(log_dir, verbose=verbose)
+    _configure_logging(log_dir, level=config.trainer.logging)
     logger.info("Starting experiment...")

-    # Seed everything in the experiment.
-    logger.info(f"Seeding everthing with seed={SEED}")
-    pl.utilities.seed.seed_everything(SEED)
+    pl.utilities.seed.seed_everything(config.trainer.seed)

     # Load classes.
-    data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
+    data_module_class = _import_class(f"text_recognizer.data.{config.dataset.type}")
     network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
     lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")

     # Initialize data object and network.
-    data_module = data_module_class(**config.data.args)
+    data_module = data_module_class(**config.dataset.args)
     network = network_class(**data_module.config(), **config.network.args)

     # Load callback and logger.
     callbacks = _configure_callbacks(config.callbacks)
-    pl_logger = _configure_logger(network, config, log_dir, use_wandb)
+    pl_logger = _configure_logger(network, config, log_dir)

     # Load ligtning model.
     lit_model = _load_lit_model(lit_model_class, network, config)
@@ -164,55 +131,28 @@ def run(
         logger=pl_logger,
         weights_save_path=str(log_dir),
     )
-    if fast_dev_run:
-        logger.info("Fast dev run...")
+
+    if config.trainer.tune and not config.trainer.args.fast_dev_run:
+        logger.info("Tuning learning rate and batch size...")
+        trainer.tune(lit_model, datamodule=data_module)
+
+    if config.trainer.train:
+        logger.info("Training network...")
         trainer.fit(lit_model, datamodule=data_module)
-    else:
-        if tune:
-            logger.info("Tuning learning rate and batch size...")
-            trainer.tune(lit_model, datamodule=data_module)
-
-        if train:
-            logger.info("Training network...")
-            trainer.fit(lit_model, datamodule=data_module)
-
-        if test:
-            logger.info("Testing network...")
-            trainer.test(lit_model, datamodule=data_module)
-
-        _save_best_weights(callbacks, use_wandb)
-
-
-@click.command()
-@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
-@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")
-@click.option("--dev", is_flag=True, help="If true, run a fast dev run.")
-@click.option(
-    "--tune", is_flag=True, help="If true, tune hyperparameters for training."
-)
-@click.option("-t", "--train", is_flag=True, help="If true, train the model.")
-@click.option("-e", "--test", is_flag=True, help="If true, test the model.")
-@click.option("-v", "--verbose", count=True)
-def cli(
-    experiment_config: str,
-    use_wandb: bool,
-    dev: bool,
-    tune: bool,
-    train: bool,
-    test: bool,
-    verbose: int,
-) -> None:
-    """Run experiment."""
-    run(
-        filename=experiment_config,
-        fast_dev_run=dev,
-        train=train,
-        test=test,
-        tune=tune,
-        use_wandb=use_wandb,
-        verbose=verbose,
-    )
+
+    if config.trainer.test and not config.trainer.args.fast_dev_run:
+        logger.info("Testing network...")
+        trainer.test(lit_model, datamodule=data_module)
+
+    if not config.trainer.args.fast_dev_run:
+        _save_best_weights(callbacks, config.trainer.wandb)
+
+
+@hydra.main(config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+    """Loads config with hydra."""
+    run(cfg)


 if __name__ == "__main__":
-    cli()
+    main()
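One coupling worth keeping in mind after the split: `steps_per_epoch: 633` in `conf/model/lit_vqvae.yaml` is documented as `num_samples / batch_size`, so it has to track `batch_size: 32` in `conf/dataset/iam_extended_paragraphs.yaml`. A quick sketch of the arithmetic (the sample count of roughly 20k paragraphs is inferred from those two numbers, not stated anywhere in the commit):

```python
# Sanity-check OneCycleLR's steps_per_epoch against the dataset batch size.
# num_samples is inferred from steps_per_epoch * batch_size; the commit does not state it.
import math

num_samples = 633 * 32  # ~20_256 training samples implied by the config comment

for batch_size in (32, 64):
    steps_per_epoch = math.ceil(num_samples / batch_size)
    print(batch_size, steps_per_epoch)  # 32 -> 633, 64 -> 317
# With Hydra, both values can be overridden together on the command line, e.g.:
#   dataset.args.batch_size=64 model.args.lr_scheduler.args.steps_per_epoch=317
```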