Diffstat (limited to 'training')
-rw-r--r--  training/.gitignore                      |  1 +
-rw-r--r--  training/configs/image_transformer.yaml |  2 ++
-rw-r--r--  training/configs/vqvae.yaml             | 89 ++++++++++++
-rw-r--r--  training/run_experiment.py              | 66 ++++++------
4 files changed, 133 insertions, 25 deletions
diff --git a/training/.gitignore b/training/.gitignore
new file mode 100644
index 0000000..333c1e9
--- /dev/null
+++ b/training/.gitignore
@@ -0,0 +1 @@
+logs/
diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml
index 228e53f..e6637f2 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/image_transformer.yaml
@@ -85,3 +85,5 @@ trainer:
    max_epochs: 512
    terminate_on_nan: true
    weights_summary: true
+
+load_checkpoint: null
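
Both configs now end with a top-level `load_checkpoint: null` key. The code that consumes it is not part of this diff; below is a minimal, hypothetical sketch of how `_load_lit_model` could honour it, assuming the standard PyTorch Lightning `load_from_checkpoint` constructor and the `lit_model_class`/`network` objects already built in `run_experiment.py`:

    # Hypothetical wiring, not in this commit: restore weights when
    # `load_checkpoint` points at a .ckpt file, otherwise build a fresh model.
    if config.load_checkpoint is not None:
        lit_model = lit_model_class.load_from_checkpoint(
            config.load_checkpoint, network=network, **config.model.args
        )
    else:
        lit_model = lit_model_class(network=network, **config.model.args)
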
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
new file mode 100644
index 0000000..90082f7
--- /dev/null
+++ b/training/configs/vqvae.yaml
@@ -0,0 +1,89 @@
+seed: 4711
+
+network:
+  desc: Configuration of the PyTorch neural network.
+  type: VQVAE
+  args:
+    in_channels: 1
+    channels: [32, 64, 96]
+    kernel_sizes: [4, 4, 4]
+    strides: [2, 2, 2]
+    num_residual_layers: 2
+    embedding_dim: 64
+    num_embeddings: 1024
+    upsampling: null
+    beta: 0.25
+    activation: leaky_relu
+    dropout_rate: 0.1
+
+model:
+  desc: Configuration of the PyTorch Lightning model.
+  type: LitVQVAEModel
+  args:
+    optimizer:
+      type: MADGRAD
+      args:
+        lr: 1.0e-3
+        momentum: 0.9
+        weight_decay: 0
+        eps: 1.0e-6
+    lr_scheduler:
+      type: OneCycleLR
+      args:
+        interval: &interval step
+        max_lr: 1.0e-3
+        three_phase: true
+        epochs: 512
+        steps_per_epoch: 317 # num_samples / batch_size
+    criterion:
+      type: MSELoss
+      args:
+        reduction: mean
+    monitor: val_loss
+    mapping: sentence_piece
+
+data:
+  desc: Configuration of the training/test data.
+  type: IAMExtendedParagraphs
+  args:
+    batch_size: 64
+    num_workers: 12
+    train_fraction: 0.8
+    augment: true
+
+callbacks:
+  - type: ModelCheckpoint
+    args:
+      monitor: val_loss
+      mode: min
+      save_last: true
+  # - type: StochasticWeightAveraging
+  #   args:
+  #     swa_epoch_start: 0.8
+  #     swa_lrs: 0.05
+  #     annealing_epochs: 10
+  #     annealing_strategy: cos
+  #     device: null
+  - type: LearningRateMonitor
+    args:
+      logging_interval: *interval
+  - type: EarlyStopping
+    args:
+      monitor: val_loss
+      mode: min
+      patience: 10
+
+trainer:
+  desc: Configuration of the PyTorch Lightning Trainer.
+  args:
+    stochastic_weight_avg: false # true
+    auto_scale_batch_size: binsearch
+    gradient_clip_val: 0
+    fast_dev_run: false
+    gpus: 1
+    precision: 16
+    max_epochs: 512
+    terminate_on_nan: true
+    weights_summary: full
+
+load_checkpoint: null
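
The scheduler block defines a YAML anchor, `interval: &interval step`, which the LearningRateMonitor callback re-uses as `logging_interval: *interval`, keeping the logging granularity in sync with how the scheduler steps. A quick sanity check that the file parses and the alias resolves, assuming the nesting above and loading via OmegaConf as `run_experiment.py` does (path relative to the repository root):

    from omegaconf import OmegaConf

    config = OmegaConf.load("training/configs/vqvae.yaml")

    # The YAML parser resolves the anchor, so both keys read "step".
    assert config.model.args.lr_scheduler.args.interval == "step"
    assert config.callbacks[1].args.logging_interval == "step"  # LearningRateMonitor

    print(config.network.type)             # VQVAE
    print(config.trainer.args.max_epochs)  # 512
    print(config.load_checkpoint)          # None
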
diff --git a/training/run_experiment.py b/training/run_experiment.py
index 1e79461..e1aae4e 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -1,4 +1,5 @@
"""Script to run experiments."""
+from datetime import datetime
import importlib
from pathlib import Path
from typing import Dict, List, Optional, Type
@@ -7,7 +8,6 @@ import click
from loguru import logger
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
-import torch
from torch import nn
from tqdm import tqdm
import wandb
@@ -18,6 +18,17 @@ CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs"
LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs"
+def _create_experiment_dir(config: DictConfig) -> Path:
+ """Creates log directory for experiment."""
+ log_dir = (
+ LOGS_DIRNAME
+ / f"{config.model.type}_{config.network.type}"
+ / datetime.now().strftime("%m%d_%H%M%S")
+ )
+ log_dir.mkdir(parents=True, exist_ok=True)
+ return log_dir
+
+
def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
"""Configure the loguru logger for output to terminal and disk."""
@@ -67,15 +78,15 @@ def _configure_callbacks(
def _configure_logger(
-    network: Type[nn.Module], args: Dict, log_dir: str, use_wandb: bool
+    network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool
) -> Type[pl.loggers.LightningLoggerBase]:
    """Configures lightning logger."""
    if use_wandb:
-        pl_logger = pl.loggers.WandbLogger(save_dir=log_dir)
+        pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir))
        pl_logger.watch(network)
        pl_logger.log_hyperparams(vars(args))
        return pl_logger
-    return pl.logger.TensorBoardLogger(save_dir=log_dir)
+    return pl.loggers.TensorBoardLogger(save_dir=str(log_dir))
def _save_best_weights(
@@ -111,6 +122,7 @@ def _load_lit_model(
def run(
    filename: str,
+    fast_dev_run: bool,
    train: bool,
    test: bool,
    tune: bool,
@@ -118,20 +130,18 @@ def run(
    verbose: int = 0,
) -> None:
"""Runs experiment."""
-    # Set log dir where logging output and weights are saved to.
-    log_dir = str(LOGS_DIRNAME)
+    # Load config.
+    file_path = CONFIGS_DIRNAME / filename
+    config = _load_config(file_path)

-    _configure_logging(None, verbose=verbose)
+    log_dir = _create_experiment_dir(config)
+    _configure_logging(log_dir, verbose=verbose)
    logger.info("Starting experiment...")

    # Seed everything in the experiment.
    logger.info(f"Seeding everthing with seed={SEED}")
    pl.utilities.seed.seed_everything(SEED)

-    # Load config.
-    file_path = CONFIGS_DIRNAME / filename
-    config = _load_config(file_path)
-
    # Load classes.
    data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
    network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
@@ -152,36 +162,41 @@ def run(
        **config.trainer.args,
        callbacks=callbacks,
        logger=pl_logger,
-        weigths_save_path=log_dir,
+        weights_save_path=str(log_dir),
    )
-
-    if tune:
-        logger.info(f"Tuning learning rate and batch size...")
-        trainer.tune(lit_model, datamodule=data_module)
-
-    if train:
-        logger.info(f"Training network...")
+    if fast_dev_run:
+        logger.info("Fast dev run...")
        trainer.fit(lit_model, datamodule=data_module)
+    else:
+        if tune:
+            logger.info("Tuning learning rate and batch size...")
+            trainer.tune(lit_model, datamodule=data_module)
+
+        if train:
+            logger.info("Training network...")
+            trainer.fit(lit_model, datamodule=data_module)

-    if test:
-        logger.info(f"Testing network...")
-        trainer.test(lit_model, datamodule=data_module)
+        if test:
+            logger.info("Testing network...")
+            trainer.test(lit_model, datamodule=data_module)

-    _save_best_weights(callbacks, use_wandb)
+        _save_best_weights(callbacks, use_wandb)
@click.command()
@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")
+@click.option("--dev", is_flag=True, help="If true, run a fast dev run.")
@click.option(
"--tune", is_flag=True, help="If true, tune hyperparameters for training."
)
-@click.option("--train", is_flag=True, help="If true, train the model.")
-@click.option("--test", is_flag=True, help="If true, test the model.")
+@click.option("-t", "--train", is_flag=True, help="If true, train the model.")
+@click.option("-e", "--test", is_flag=True, help="If true, test the model.")
@click.option("-v", "--verbose", count=True)
def cli(
    experiment_config: str,
    use_wandb: bool,
+    dev: bool,
    tune: bool,
    train: bool,
    test: bool,
@@ -190,6 +205,7 @@ def cli(
"""Run experiment."""
run(
filename=experiment_config,
+ fast_dev_run=dev,
train=train,
test=test,
tune=tune,
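
With the new `--dev` flag and the short aliases for `--train`/`--test`, typical invocations of the updated script would look roughly like this (a sketch based only on the click options above, assuming the module's usual `cli()` entry point; the config filename is resolved against `training/configs/`):

    # quick end-to-end check of the VQVAE pipeline
    python training/run_experiment.py -f vqvae.yaml --dev

    # full training + test run with Weights & Biases logging
    python training/run_experiment.py -f vqvae.yaml -t -e --use_wandb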