From 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 22 Apr 2021 08:15:58 +0200 Subject: Fixed training script, able to train vqvae --- training/.gitignore | 1 + training/configs/image_transformer.yaml | 2 + training/configs/vqvae.yaml | 89 +++++++++++++++++++++++++++++++++ training/run_experiment.py | 66 +++++++++++++++--------- 4 files changed, 133 insertions(+), 25 deletions(-) create mode 100644 training/.gitignore create mode 100644 training/configs/vqvae.yaml (limited to 'training') diff --git a/training/.gitignore b/training/.gitignore new file mode 100644 index 0000000..333c1e9 --- /dev/null +++ b/training/.gitignore @@ -0,0 +1 @@ +logs/ diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml index 228e53f..e6637f2 100644 --- a/training/configs/image_transformer.yaml +++ b/training/configs/image_transformer.yaml @@ -85,3 +85,5 @@ trainer: max_epochs: 512 terminate_on_nan: true weights_summary: true + +load_checkpoint: null diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml new file mode 100644 index 0000000..90082f7 --- /dev/null +++ b/training/configs/vqvae.yaml @@ -0,0 +1,89 @@ +seed: 4711 + +network: + desc: Configuration of the PyTorch neural network. + type: VQVAE + args: + in_channels: 1 + channels: [32, 64, 96] + kernel_sizes: [4, 4, 4] + strides: [2, 2, 2] + num_residual_layers: 2 + embedding_dim: 64 + num_embeddings: 1024 + upsampling: null + beta: 0.25 + activation: leaky_relu + dropout_rate: 0.1 + +model: + desc: Configuration of the PyTorch Lightning model. + type: LitVQVAEModel + args: + optimizer: + type: MADGRAD + args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + lr_scheduler: + type: OneCycleLR + args: + interval: &interval step + max_lr: 1.0e-3 + three_phase: true + epochs: 512 + steps_per_epoch: 317 # num_samples / batch_size + criterion: + type: MSELoss + args: + reduction: mean + monitor: val_loss + mapping: sentence_piece + +data: + desc: Configuration of the training/test data. + type: IAMExtendedParagraphs + args: + batch_size: 64 + num_workers: 12 + train_fraction: 0.8 + augment: true + +callbacks: + - type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true + # - type: StochasticWeightAveraging + # args: + # swa_epoch_start: 0.8 + # swa_lrs: 0.05 + # annealing_epochs: 10 + # annealing_strategy: cos + # device: null + - type: LearningRateMonitor + args: + logging_interval: *interval + - type: EarlyStopping + args: + monitor: val_loss + mode: min + patience: 10 + +trainer: + desc: Configuration of the PyTorch Lightning Trainer. 
+ args: + stochastic_weight_avg: false # true + auto_scale_batch_size: binsearch + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: 512 + terminate_on_nan: true + weights_summary: full + +load_checkpoint: null diff --git a/training/run_experiment.py b/training/run_experiment.py index 1e79461..e1aae4e 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -1,4 +1,5 @@ """Script to run experiments.""" +from datetime import datetime import importlib from pathlib import Path from typing import Dict, List, Optional, Type @@ -7,7 +8,6 @@ import click from loguru import logger from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl -import torch from torch import nn from tqdm import tqdm import wandb @@ -18,6 +18,17 @@ CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs" LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs" +def _create_experiment_dir(config: DictConfig) -> Path: + """Creates log directory for experiment.""" + log_dir = ( + LOGS_DIRNAME + / f"{config.model.type}_{config.network.type}" + / datetime.now().strftime("%m%d_%H%M%S") + ) + log_dir.mkdir(parents=True, exist_ok=True) + return log_dir + + def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: """Configure the loguru logger for output to terminal and disk.""" @@ -67,15 +78,15 @@ def _configure_callbacks( def _configure_logger( - network: Type[nn.Module], args: Dict, log_dir: str, use_wandb: bool + network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool ) -> Type[pl.loggers.LightningLoggerBase]: """Configures lightning logger.""" if use_wandb: - pl_logger = pl.loggers.WandbLogger(save_dir=log_dir) + pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir)) pl_logger.watch(network) pl_logger.log_hyperparams(vars(args)) return pl_logger - return pl.logger.TensorBoardLogger(save_dir=log_dir) + return pl.loggers.TensorBoardLogger(save_dir=str(log_dir)) def _save_best_weights( @@ -111,6 +122,7 @@ def _load_lit_model( def run( filename: str, + fast_dev_run: bool, train: bool, test: bool, tune: bool, @@ -118,20 +130,18 @@ def run( verbose: int = 0, ) -> None: """Runs experiment.""" - # Set log dir where logging output and weights are saved to. - log_dir = str(LOGS_DIRNAME) + # Load config. + file_path = CONFIGS_DIRNAME / filename + config = _load_config(file_path) - _configure_logging(None, verbose=verbose) + log_dir = _create_experiment_dir(config) + _configure_logging(log_dir, verbose=verbose) logger.info("Starting experiment...") # Seed everything in the experiment. logger.info(f"Seeding everthing with seed={SEED}") pl.utilities.seed.seed_everything(SEED) - # Load config. - file_path = CONFIGS_DIRNAME / filename - config = _load_config(file_path) - # Load classes. 
data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") network_class = _import_class(f"text_recognizer.networks.{config.network.type}") @@ -152,36 +162,41 @@ def run( **config.trainer.args, callbacks=callbacks, logger=pl_logger, - weigths_save_path=log_dir, + weights_save_path=str(log_dir), ) - - if tune: - logger.info(f"Tuning learning rate and batch size...") - trainer.tune(lit_model, datamodule=data_module) - - if train: - logger.info(f"Training network...") + if fast_dev_run: + logger.info("Fast dev run...") trainer.fit(lit_model, datamodule=data_module) + else: + if tune: + logger.info("Tuning learning rate and batch size...") + trainer.tune(lit_model, datamodule=data_module) + + if train: + logger.info("Training network...") + trainer.fit(lit_model, datamodule=data_module) - if test: - logger.info(f"Testing network...") - trainer.test(lit_model, datamodule=data_module) + if test: + logger.info("Testing network...") + trainer.test(lit_model, datamodule=data_module) - _save_best_weights(callbacks, use_wandb) + _save_best_weights(callbacks, use_wandb) @click.command() @click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") @click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.") +@click.option("--dev", is_flag=True, help="If true, run a fast dev run.") @click.option( "--tune", is_flag=True, help="If true, tune hyperparameters for training." ) -@click.option("--train", is_flag=True, help="If true, train the model.") -@click.option("--test", is_flag=True, help="If true, test the model.") +@click.option("-t", "--train", is_flag=True, help="If true, train the model.") +@click.option("-e", "--test", is_flag=True, help="If true, test the model.") @click.option("-v", "--verbose", count=True) def cli( experiment_config: str, use_wandb: bool, + dev: bool, tune: bool, train: bool, test: bool, @@ -190,6 +205,7 @@ def cli( """Run experiment.""" run( filename=experiment_config, + fast_dev_run=dev, train=train, test=test, tune=tune, -- cgit v1.2.3-70-g09d2
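For orientation, the optimizer and lr_scheduler blocks in the new vqvae.yaml map onto standard PyTorch objects. The sketch below is an assumption about how LitVQVAEModel consumes them, since that class sits outside this diff, and it presumes the third-party madgrad package provides the MADGRAD optimizer named in the config:

    import torch
    from madgrad import MADGRAD  # assumed third-party package for the MADGRAD optimizer named in the config

    def configure_optimizers(model: torch.nn.Module) -> dict:
        # Values copied from the model.args section of vqvae.yaml.
        optimizer = MADGRAD(
            model.parameters(), lr=1.0e-3, momentum=0.9, weight_decay=0, eps=1.0e-6
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1.0e-3,
            three_phase=True,
            epochs=512,
            steps_per_epoch=317,  # num_samples / batch_size, per the comment in the config
        )
        # interval: step makes Lightning step the scheduler after every batch instead of every epoch.
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

With batch_size 64, steps_per_epoch 317 corresponds to roughly 317 * 64 = 20288 training samples per epoch.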
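The type fields in the config are resolved by run_experiment.py through an _import_class helper that this patch does not touch. A minimal sketch of such a helper, assuming dotted paths like "text_recognizer.networks.VQVAE", could look like this:

    import importlib

    def _import_class(module_and_class_name: str) -> type:
        # Split e.g. "text_recognizer.networks.VQVAE" into module path and class name.
        module_name, class_name = module_and_class_name.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

With the classes resolved this way, a VQ-VAE training run is launched roughly as python training/run_experiment.py -f vqvae.yaml -t --use_wandb, using the flag names defined by the click options above; the config filename is resolved relative to training/configs.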