Diffstat (limited to 'training')
 -rw-r--r--  training/.gitignore                                  |   1
 -rw-r--r--  training/conf/callbacks/default.yaml                 |  14
 -rw-r--r--  training/conf/callbacks/swa.yaml                     |  16
 -rw-r--r--  training/conf/cnn_transformer.yaml (renamed from training/configs/cnn_transformer.yaml) | 0
 -rw-r--r--  training/conf/config.yaml                            |   6
 -rw-r--r--  training/conf/dataset/iam_extended_paragraphs.yaml   |   7
 -rw-r--r--  training/conf/model/lit_vqvae.yaml                   |  24
 -rw-r--r--  training/conf/network/vqvae.yaml                     |  14
 -rw-r--r--  training/conf/trainer/default.yaml                   |  18
 -rw-r--r--  training/configs/vqvae.yaml                          |  89
 -rw-r--r--  training/run_experiment.py                           | 136
 11 files changed, 138 insertions(+), 187 deletions(-)
diff --git a/training/.gitignore b/training/.gitignore
index 333c1e9..7d268ea 100644
--- a/training/.gitignore
+++ b/training/.gitignore
@@ -1 +1,2 @@
logs/
+outputs/
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
new file mode 100644
index 0000000..74dc30c
--- /dev/null
+++ b/training/conf/callbacks/default.yaml
@@ -0,0 +1,14 @@
+# @package _group_
+- type: ModelCheckpoint
+ args:
+ monitor: val_loss
+ mode: min
+ save_last: true
+- type: LearningRateMonitor
+ args:
+ logging_interval: step
+# - type: EarlyStopping
+# args:
+# monitor: val_loss
+# mode: min
+# patience: 10
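Each callbacks entry pairs a class name with its keyword arguments. The body of _configure_callbacks in run_experiment.py is not shown in this diff, but a minimal sketch of how such a list is presumably turned into PyTorch Lightning callback objects (the helper name below is hypothetical) looks like this:

from typing import List

import pytorch_lightning as pl
from omegaconf import DictConfig


def configure_callbacks_sketch(callback_configs: DictConfig) -> List[pl.Callback]:
    """Instantiate every callback listed in conf/callbacks/*.yaml."""
    callbacks = []
    for entry in callback_configs:
        # Resolve the class by name, e.g. ModelCheckpoint or LearningRateMonitor.
        callback_class = getattr(pl.callbacks, entry.type)
        callbacks.append(callback_class(**entry.args))
    return callbacks
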
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
new file mode 100644
index 0000000..144ad6e
--- /dev/null
+++ b/training/conf/callbacks/swa.yaml
@@ -0,0 +1,16 @@
+# @package _group_
+- type: ModelCheckpoint
+ args:
+ monitor: val_loss
+ mode: min
+ save_last: true
+- type: StochasticWeightAveraging
+ args:
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+- type: LearningRateMonitor
+ args:
+ logging_interval: step
diff --git a/training/configs/cnn_transformer.yaml b/training/conf/cnn_transformer.yaml
index a4f16df..a4f16df 100644
--- a/training/configs/cnn_transformer.yaml
+++ b/training/conf/cnn_transformer.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
new file mode 100644
index 0000000..11adeb7
--- /dev/null
+++ b/training/conf/config.yaml
@@ -0,0 +1,6 @@
+defaults:
+ - network: vqvae
+ - model: lit_vqvae
+ - dataset: iam_extended_paragraphs
+ - trainer: default
+ - callbacks: default
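Each entry in defaults pulls in one file from the matching config group directory (conf/network/vqvae.yaml, conf/model/lit_vqvae.yaml, and so on), and any group can be swapped from the command line, e.g. callbacks=swa. A minimal sketch for inspecting the composed config outside a training run, assuming Hydra 1.1+ where initialize/compose are exposed at the top level (in 1.0 they live under hydra.experimental):

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="conf"):
    # Swap the default callbacks group for the SWA variant defined above.
    cfg = compose(config_name="config", overrides=["callbacks=swa"])
    assert cfg.network.type == "VQVAE"
    print(OmegaConf.to_yaml(cfg.callbacks))
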
diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml
new file mode 100644
index 0000000..6bd7fc9
--- /dev/null
+++ b/training/conf/dataset/iam_extended_paragraphs.yaml
@@ -0,0 +1,7 @@
+# @package _group_
+type: IAMExtendedParagraphs
+args:
+ batch_size: 32
+ num_workers: 12
+ train_fraction: 0.8
+ augment: true
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
new file mode 100644
index 0000000..90780b7
--- /dev/null
+++ b/training/conf/model/lit_vqvae.yaml
@@ -0,0 +1,24 @@
+# @package _group_
+type: LitVQVAEModel
+args:
+ optimizer:
+ type: MADGRAD
+ args:
+ lr: 1.0e-3
+ momentum: 0.9
+ weight_decay: 0
+ eps: 1.0e-6
+ lr_scheduler:
+ type: OneCycleLR
+ args:
+ interval: step
+ max_lr: 1.0e-3
+ three_phase: true
+ epochs: 64
+ steps_per_epoch: 633 # num_samples / batch_size
+ criterion:
+ type: MSELoss
+ args:
+ reduction: mean
+ monitor: val_loss
+ mapping: sentence_piece
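The optimizer, lr_scheduler, and criterion blocks mirror the objects the Lightning module builds internally. The model code itself is not part of this diff, so the following is only a sketch of how configure_optimizers might translate this block; the madgrad package import is an assumption:

import torch
from madgrad import MADGRAD  # assumes the madgrad package is installed


def configure_optimizers_sketch(model: torch.nn.Module):
    optimizer = MADGRAD(
        model.parameters(), lr=1.0e-3, momentum=0.9, weight_decay=0, eps=1.0e-6
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1.0e-3,
        three_phase=True,
        epochs=64,
        steps_per_epoch=633,  # num_samples / batch_size, as noted in the config
    )
    # interval: step tells Lightning to advance the scheduler every batch.
    return [optimizer], [{"scheduler": scheduler, "interval": "step", "monitor": "val_loss"}]
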
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
new file mode 100644
index 0000000..8c30bbd
--- /dev/null
+++ b/training/conf/network/vqvae.yaml
@@ -0,0 +1,14 @@
+# @package _group_
+type: VQVAE
+args:
+ in_channels: 1
+ channels: [32, 64, 64]
+ kernel_sizes: [4, 4, 4]
+ strides: [2, 2, 2]
+ num_residual_layers: 2
+ embedding_dim: 64
+ num_embeddings: 256
+ upsampling: null
+ beta: 0.25
+ activation: leaky_relu
+ dropout_rate: 0.2
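With three stride-2 convolutions the encoder downsamples the input by roughly a factor of 8 per spatial dimension before quantising each position against the 256-entry codebook of 64-dimensional embeddings. A quick back-of-the-envelope check (a sketch; it assumes each kernel-4/stride-2 convolution is padded so that it exactly halves the spatial size):

def latent_size(height: int, width: int, strides=(2, 2, 2)) -> tuple:
    """Spatial size of the latent grid produced by the encoder above."""
    for stride in strides:
        height, width = height // stride, width // stride
    return height, width


# e.g. latent_size(64, 64) -> (8, 8); each of those positions holds one
# codebook vector of dimension embedding_dim=64.
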
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
new file mode 100644
index 0000000..82afd93
--- /dev/null
+++ b/training/conf/trainer/default.yaml
@@ -0,0 +1,18 @@
+# @package _group_
+seed: 4711
+load_checkpoint: null
+wandb: false
+tune: false
+train: true
+test: true
+logging: INFO
+args:
+ stochastic_weight_avg: false
+ auto_scale_batch_size: binsearch
+ gradient_clip_val: 0
+ fast_dev_run: false
+ gpus: 1
+ precision: 16
+ max_epochs: 64
+ terminate_on_nan: true
+ weights_summary: top
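Everything under args is handed straight to the pl.Trainer constructor, while the top-level keys (seed, load_checkpoint, wandb, tune, train, test, logging) only steer run_experiment.py. The Trainer construction is only partially visible in the diff below, so the **config.trainer.args expansion here is an assumption:

import pytorch_lightning as pl
from omegaconf import DictConfig


def build_trainer_sketch(config: DictConfig, callbacks, pl_logger, log_dir) -> pl.Trainer:
    # gpus, precision, max_epochs, terminate_on_nan, ... come from trainer/default.yaml.
    return pl.Trainer(
        **config.trainer.args,
        callbacks=callbacks,
        logger=pl_logger,
        weights_save_path=str(log_dir),
    )
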
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
deleted file mode 100644
index 13d7c97..0000000
--- a/training/configs/vqvae.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-seed: 4711
-
-network:
- desc: Configuration of the PyTorch neural network.
- type: VQVAE
- args:
- in_channels: 1
- channels: [32, 64, 64, 96, 96]
- kernel_sizes: [4, 4, 4, 4, 4]
- strides: [2, 2, 2, 2, 2]
- num_residual_layers: 2
- embedding_dim: 512
- num_embeddings: 1024
- upsampling: null
- beta: 0.25
- activation: leaky_relu
- dropout_rate: 0.2
-
-model:
- desc: Configuration of the PyTorch Lightning model.
- type: LitVQVAEModel
- args:
- optimizer:
- type: MADGRAD
- args:
- lr: 1.0e-3
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
- lr_scheduler:
- type: OneCycleLR
- args:
- interval: &interval step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 64
- steps_per_epoch: 633 # num_samples / batch_size
- criterion:
- type: MSELoss
- args:
- reduction: mean
- monitor: val_loss
- mapping: sentence_piece
-
-data:
- desc: Configuration of the training/test data.
- type: IAMExtendedParagraphs
- args:
- batch_size: 32
- num_workers: 12
- train_fraction: 0.8
- augment: true
-
-callbacks:
- - type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
- - type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
- - type: LearningRateMonitor
- args:
- logging_interval: *interval
- # - type: EarlyStopping
- # args:
- # monitor: val_loss
- # mode: min
- # patience: 10
-
-trainer:
- desc: Configuration of the PyTorch Lightning Trainer.
- args:
- stochastic_weight_avg: true
- auto_scale_batch_size: binsearch
- gradient_clip_val: 0
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: 64
- terminate_on_nan: true
- weights_summary: top
-
-load_checkpoint: null
diff --git a/training/run_experiment.py b/training/run_experiment.py
index bdefbf0..2b3ecab 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -4,17 +4,15 @@ import importlib
from pathlib import Path
from typing import Dict, List, Optional, Type
-import click
+import hydra
from loguru import logger
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
import pytorch_lightning as pl
from torch import nn
from tqdm import tqdm
import wandb
-SEED = 4711
-CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs"
LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs"
@@ -29,21 +27,10 @@ def _create_experiment_dir(config: DictConfig) -> Path:
return log_dir
-def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
+def _configure_logging(log_dir: Optional[Path], level: str) -> None:
"""Configure the loguru logger for output to terminal and disk."""
-
- def _get_level(verbose: int) -> str:
- """Sets the logger level."""
- levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"}
- verbose = min(verbose, 2)
- return levels[verbose]
-
# Remove default logger to get tqdm to work properly.
logger.remove()
-
- # Fetch verbosity level.
- level = _get_level(verbose)
-
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
if log_dir is not None:
logger.add(
@@ -52,14 +39,6 @@ def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
)
-def _load_config(file_path: Path) -> DictConfig:
- """Return experiment config."""
- logger.info(f"Loading config from: {file_path}")
- if not file_path.exists():
- raise FileNotFoundError(f"Experiment config not found at: {file_path}")
- return OmegaConf.load(file_path)
-
-
def _import_class(module_and_class_name: str) -> type:
"""Import class from module."""
module_name, class_name = module_and_class_name.rsplit(".", 1)
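The hunk above only shows the first lines of _import_class, which resolves the dotted type strings from the YAML configs into classes. Presumably it continues with an importlib lookup, roughly:

import importlib


def _import_class(module_and_class_name: str) -> type:
    """Import class from module, e.g. 'text_recognizer.networks.VQVAE'."""
    module_name, class_name = module_and_class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
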
@@ -78,14 +57,16 @@ def _configure_callbacks(
def _configure_logger(
- network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool
+ network: Type[nn.Module], config: DictConfig, log_dir: Path
) -> Type[pl.loggers.LightningLoggerBase]:
"""Configures lightning logger."""
- if use_wandb:
+ if config.trainer.wandb:
+ logger.info("Logging model with W&B")
pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir))
pl_logger.watch(network)
- pl_logger.log_hyperparams(vars(args))
+ pl_logger.log_hyperparams(vars(config))
return pl_logger
+ logger.info("Logging model with Tensorboard")
return pl.loggers.TensorBoardLogger(save_dir=str(log_dir))
@@ -110,50 +91,36 @@ def _load_lit_model(
lit_model_class: type, network: Type[nn.Module], config: DictConfig
) -> Type[pl.LightningModule]:
"""Load lightning model."""
- if config.load_checkpoint is not None:
+ if config.trainer.load_checkpoint is not None:
logger.info(
- f"Loading network weights from checkpoint: {config.load_checkpoint}"
+ f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}"
)
return lit_model_class.load_from_checkpoint(
- config.load_checkpoint, network=network, **config.model.args
+ config.trainer.load_checkpoint, network=network, **config.model.args
)
return lit_model_class(network=network, **config.model.args)
-def run(
- filename: str,
- fast_dev_run: bool,
- train: bool,
- test: bool,
- tune: bool,
- use_wandb: bool,
- verbose: int = 0,
-) -> None:
+def run(config: DictConfig) -> None:
"""Runs experiment."""
- # Load config.
- file_path = CONFIGS_DIRNAME / filename
- config = _load_config(file_path)
-
log_dir = _create_experiment_dir(config)
- _configure_logging(log_dir, verbose=verbose)
+ _configure_logging(log_dir, level=config.trainer.logging)
logger.info("Starting experiment...")
- # Seed everything in the experiment.
- logger.info(f"Seeding everthing with seed={SEED}")
- pl.utilities.seed.seed_everything(SEED)
+ pl.utilities.seed.seed_everything(config.trainer.seed)
# Load classes.
- data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
+ data_module_class = _import_class(f"text_recognizer.data.{config.dataset.type}")
network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")
# Initialize data object and network.
- data_module = data_module_class(**config.data.args)
+ data_module = data_module_class(**config.dataset.args)
network = network_class(**data_module.config(), **config.network.args)
# Load callback and logger.
callbacks = _configure_callbacks(config.callbacks)
- pl_logger = _configure_logger(network, config, log_dir, use_wandb)
+ pl_logger = _configure_logger(network, config, log_dir)
# Load lightning model.
lit_model = _load_lit_model(lit_model_class, network, config)
@@ -164,55 +131,28 @@ def run(
logger=pl_logger,
weights_save_path=str(log_dir),
)
- if fast_dev_run:
- logger.info("Fast dev run...")
+
+ if config.trainer.tune and not config.trainer.args.fast_dev_run:
+ logger.info("Tuning learning rate and batch size...")
+ trainer.tune(lit_model, datamodule=data_module)
+
+ if config.trainer.train:
+ logger.info("Training network...")
trainer.fit(lit_model, datamodule=data_module)
- else:
- if tune:
- logger.info("Tuning learning rate and batch size...")
- trainer.tune(lit_model, datamodule=data_module)
-
- if train:
- logger.info("Training network...")
- trainer.fit(lit_model, datamodule=data_module)
-
- if test:
- logger.info("Testing network...")
- trainer.test(lit_model, datamodule=data_module)
-
- _save_best_weights(callbacks, use_wandb)
-
-
-@click.command()
-@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
-@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")
-@click.option("--dev", is_flag=True, help="If true, run a fast dev run.")
-@click.option(
- "--tune", is_flag=True, help="If true, tune hyperparameters for training."
-)
-@click.option("-t", "--train", is_flag=True, help="If true, train the model.")
-@click.option("-e", "--test", is_flag=True, help="If true, test the model.")
-@click.option("-v", "--verbose", count=True)
-def cli(
- experiment_config: str,
- use_wandb: bool,
- dev: bool,
- tune: bool,
- train: bool,
- test: bool,
- verbose: int,
-) -> None:
- """Run experiment."""
- run(
- filename=experiment_config,
- fast_dev_run=dev,
- train=train,
- test=test,
- tune=tune,
- use_wandb=use_wandb,
- verbose=verbose,
- )
+
+ if config.trainer.test and not config.trainer.args.fast_dev_run:
+ logger.info("Testing network...")
+ trainer.test(lit_model, datamodule=data_module)
+
+ if not config.trainer.args.fast_dev_run:
+ _save_best_weights(callbacks, config.trainer.wandb)
+
+
+@hydra.main(config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+ """Loads config with hydra."""
+ run(cfg)
if __name__ == "__main__":
- cli()
+ main()
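
With the click CLI removed, the former flags travel as Hydra overrides, and @hydra.main creates a fresh working directory per run (by default under outputs/<date>/<time>/, which is why outputs/ is added to .gitignore above). A minimal sketch of what the decorator provides (the inspect helper is hypothetical), with example overrides in the comments:

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="config")
def inspect(cfg: DictConfig) -> None:
    # Hydra has already composed conf/config.yaml with its defaults list and any
    # command-line overrides, and changed into the per-run output directory.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    inspect()
    # Example invocations replacing the old click flags:
    #   python run_experiment.py trainer.args.fast_dev_run=true   # was --dev
    #   python run_experiment.py trainer.wandb=true               # was --use_wandb
    #   python run_experiment.py callbacks=swa trainer.args.max_epochs=128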