summaryrefslogtreecommitdiff
path: root/src/training
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
commite1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree70b482f890c9ad2be104f0bff8f2172e8411a2be /src/training
parentfe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/training')
-rw-r--r--src/training/experiments/sample_experiment.yml127
-rw-r--r--src/training/prepare_experiments.py4
-rw-r--r--src/training/run_experiment.py238
-rw-r--r--src/training/run_sweep.py86
-rw-r--r--src/training/sweep_emnist.yml26
-rw-r--r--src/training/sweep_emnist_resnet.yml50
-rw-r--r--src/training/trainer/callbacks/__init__.py15
-rw-r--r--src/training/trainer/callbacks/base.py78
-rw-r--r--src/training/trainer/callbacks/checkpoint.py95
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py52
-rw-r--r--src/training/trainer/callbacks/progress_bar.py19
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py32
-rw-r--r--src/training/trainer/train.py170
-rw-r--r--src/training/trainer/util.py9
14 files changed, 686 insertions, 315 deletions
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index b00bd5a..17e220e 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -1,17 +1,20 @@
experiment_group: Sample Experiments
experiments:
- - dataset: EmnistDataset
- dataset_args:
- sample_to_balance: true
- subsample_fraction: null
- transform: null
- target_transform: null
- seed: 4711
- data_loader_args:
- splits: [train, val]
- shuffle: true
- num_workers: 8
- cuda: true
+ - train_args:
+ batch_size: 256
+ max_epochs: 32
+ dataset:
+ type: EmnistDataset
+ args:
+ sample_to_balance: true
+ subsample_fraction: null
+ transform: null
+ target_transform: null
+ seed: 4711
+ train_args:
+ num_workers: 6
+ train_fraction: 0.8
+
model: CharacterModel
metrics: [accuracy]
# network: MLP
@@ -19,65 +22,81 @@ experiments:
# input_size: 784
# hidden_size: 512
# output_size: 80
- # num_layers: 3
- # dropout_rate: 0
+ # num_layers: 5
+ # dropout_rate: 0.2
# activation_fn: SELU
- network: ResidualNetwork
- network_args:
- in_channels: 1
- num_classes: 80
- depths: [2, 1]
- block_sizes: [96, 32]
+ network:
+ type: ResidualNetwork
+ args:
+ in_channels: 1
+ num_classes: 80
+ depths: [2, 2]
+ block_sizes: [64, 64]
+ activation: leaky_relu
+ stn: true
+ # network:
+ # type: WideResidualNetwork
+ # args:
+ # in_channels: 1
+ # num_classes: 80
+ # depth: 10
+ # num_layers: 3
+ # width_factor: 4
+ # dropout_rate: 0.2
+ # activation: SELU
# network: LeNet
# network_args:
# output_size: 62
# activation_fn: GELU
- train_args:
- batch_size: 256
- epochs: 32
- criterion: CrossEntropyLoss
- criterion_args:
- weight: null
- ignore_index: -100
- reduction: mean
- # optimizer: RMSprop
- # optimizer_args:
- # lr: 1.e-3
- # alpha: 0.9
- # eps: 1.e-7
- # momentum: 0
- # weight_decay: 0
- # centered: false
- optimizer: AdamW
- optimizer_args:
- lr: 1.e-03
- betas: [0.9, 0.999]
- eps: 1.e-08
- # weight_decay: 5.e-4
- amsgrad: false
- # lr_scheduler: null
- lr_scheduler: OneCycleLR
- lr_scheduler_args:
- max_lr: 1.e-03
- epochs: 32
- anneal_strategy: linear
- callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
+ criterion:
+ type: CrossEntropyLoss
+ args:
+ weight: null
+ ignore_index: -100
+ reduction: mean
+ optimizer:
+ type: AdamW
+ args:
+ lr: 1.e-02
+ betas: [0.9, 0.999]
+ eps: 1.e-08
+ # weight_decay: 5.e-4
+ amsgrad: false
+ # lr_scheduler:
+ # type: OneCycleLR
+ # args:
+ # max_lr: 1.e-03
+ # epochs: null
+ # anneal_strategy: linear
+ lr_scheduler:
+ type: CosineAnnealingLR
+ args:
+ T_max: null
+ swa_args:
+ start: 2
+ lr: 5.e-2
+ callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping, SWA] # OneCycleLR]
callback_args:
Checkpoint:
monitor: val_accuracy
ProgressBar:
- epochs: 32
+ epochs: null
log_batch_frequency: 100
EarlyStopping:
monitor: val_loss
min_delta: 0.0
- patience: 3
+ patience: 5
mode: min
WandbCallback:
log_batch_frequency: 10
WandbImageLogger:
num_examples: 4
- OneCycleLR:
+ use_transpose: true
+ # OneCycleLR:
+ # null
+ SWA:
null
- verbosity: 1 # 0, 1, 2
+ verbosity: 0 # 0, 1, 2
resume_experiment: null
+ test: true
+ test_metric: test_accuracy
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
index 4c3f9ba..e00540c 100644
--- a/src/training/prepare_experiments.py
+++ b/src/training/prepare_experiments.py
@@ -9,14 +9,14 @@ import yaml
def run_experiments(experiments_filename: str) -> None:
"""Run experiment from file."""
- with open(experiments_filename) as f:
+ with open(experiments_filename, "r") as f:
experiments_config = yaml.safe_load(f)
num_experiments = len(experiments_config["experiments"])
for index in range(num_experiments):
experiment_config = experiments_config["experiments"][index]
experiment_config["experiment_group"] = experiments_config["experiment_group"]
- cmd = f"python training/run_experiment.py --gpu=-1 --save --experiment_config='{json.dumps(experiment_config)}'"
+ cmd = f"python training/run_experiment.py --gpu=-1 --save '{json.dumps(experiment_config)}'"
print(cmd)
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 8c063ff..4317d66 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,18 +6,19 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, Tuple, Type
+from typing import Callable, Dict, List, Tuple, Type
import click
from loguru import logger
import torch
from tqdm import tqdm
from training.gpu_manager import GPUManager
-from training.trainer.callbacks import CallbackList
+from training.trainer.callbacks import Callback, CallbackList
from training.trainer.train import Trainer
import wandb
import yaml
+
from text_recognizer.models import Model
@@ -37,10 +38,14 @@ def get_level(experiment_config: Dict) -> int:
return 10
-def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
+def create_experiment_dir(experiment_config: Dict) -> Path:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
- experiment_dir = EXPERIMENTS_DIRNAME / model.__name__
+ experiment_dir = EXPERIMENTS_DIRNAME / (
+ f"{experiment_config['model']}_"
+ + f"{experiment_config['dataset']['type']}_"
+ + f"{experiment_config['network']['type']}"
+ )
if experiment_config["resume_experiment"] is None:
experiment = datetime.now().strftime("%m%d_%H%M%S")
logger.debug(f"Creating a new experiment called {experiment}")
@@ -54,70 +59,89 @@ def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
experiment = experiment_config["resume_experiment"]
if not str(experiment_dir / experiment) in available_experiments:
raise FileNotFoundError("Experiment does not exist.")
- logger.debug(f"Resuming the experiment {experiment}")
experiment_dir = experiment_dir / experiment
- return experiment_dir
+ # Create log and model directories.
+ log_dir = experiment_dir / "log"
+ model_dir = experiment_dir / "model"
+
+ return experiment_dir, log_dir, model_dir
-def check_args(args: Dict) -> Dict:
+
+def check_args(args: Dict, train_args: Dict) -> Dict:
"""Checks that the arguments are not None."""
+ args = args or {}
+
+ # I just want to set total epochs in train args, instead of changing all parameter.
+ if "epochs" in args and args["epochs"] is None:
+ args["epochs"] = train_args["max_epochs"]
+
+ # For CosineAnnealingLR.
+ if "T_max" in args and args["T_max"] is None:
+ args["T_max"] = train_args["max_epochs"]
+
return args or {}
def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
# Import the data loader arguments.
- data_loader_args = experiment_config.get("data_loader_args", {})
train_args = experiment_config.get("train_args", {})
- data_loader_args["batch_size"] = train_args["batch_size"]
- data_loader_args["dataset"] = experiment_config["dataset"]
- data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
+
+ # Load the dataset module.
+ dataset_args = experiment_config.get("dataset", {})
+ dataset_args["train_args"]["batch_size"] = train_args["batch_size"]
+ datasets_module = importlib.import_module("text_recognizer.datasets")
+ dataset_ = getattr(datasets_module, dataset_args["type"])
# Import the model module and model arguments.
models_module = importlib.import_module("text_recognizer.models")
model_class_ = getattr(models_module, experiment_config["model"])
# Import metrics.
- metric_fns_ = {
- metric: getattr(models_module, metric)
- for metric in experiment_config["metrics"]
- }
+ metric_fns_ = (
+ {
+ metric: getattr(models_module, metric)
+ for metric in experiment_config["metrics"]
+ }
+ if experiment_config["metrics"] is not None
+ else None
+ )
# Import network module and arguments.
network_module = importlib.import_module("text_recognizer.networks")
- network_fn_ = getattr(network_module, experiment_config["network"])
- network_args = experiment_config.get("network_args", {})
+ network_fn_ = getattr(network_module, experiment_config["network"]["type"])
+ network_args = experiment_config["network"].get("args", {})
# Criterion
- criterion_ = getattr(torch.nn, experiment_config["criterion"])
- criterion_args = experiment_config.get("criterion_args", {})
+ criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
+ criterion_args = experiment_config["criterion"].get("args", {})
- # Optimizer
- optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
- optimizer_args = experiment_config.get("optimizer_args", {})
-
- # Callbacks
- callback_modules = importlib.import_module("training.trainer.callbacks")
- callbacks = [
- getattr(callback_modules, callback)(
- **check_args(experiment_config["callback_args"][callback])
- )
- for callback in experiment_config["callbacks"]
- ]
+ # Optimizers
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+ optimizer_args = experiment_config["optimizer"].get("args", {})
# Learning rate scheduler
+ lr_scheduler_ = None
+ lr_scheduler_args = None
if experiment_config["lr_scheduler"] is not None:
lr_scheduler_ = getattr(
- torch.optim.lr_scheduler, experiment_config["lr_scheduler"]
+ torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
+ )
+ lr_scheduler_args = check_args(
+ experiment_config["lr_scheduler"].get("args", {}), train_args
)
- lr_scheduler_args = experiment_config.get("lr_scheduler_args", {})
+
+ # SWA scheduler.
+ if "swa_args" in experiment_config:
+ swa_args = check_args(experiment_config.get("swa_args", {}), train_args)
else:
- lr_scheduler_ = None
- lr_scheduler_args = None
+ swa_args = None
model_args = {
- "data_loader_args": data_loader_args,
+ "dataset": dataset_,
+ "dataset_args": dataset_args,
"metrics": metric_fns_,
"network_fn": network_fn_,
"network_args": network_args,
@@ -127,43 +151,33 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
"optimizer_args": optimizer_args,
"lr_scheduler": lr_scheduler_,
"lr_scheduler_args": lr_scheduler_args,
+ "swa_args": swa_args,
}
- return model_class_, model_args, callbacks
-
+ return model_class_, model_args
-def run_experiment(
- experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
-) -> None:
- """Runs an experiment."""
-
- # Load the modules and model arguments.
- model_class_, model_args, callbacks = load_modules_and_arguments(experiment_config)
-
- # Initializes the model with experiment config.
- model = model_class_(**model_args, device=device)
- # Instantiate a CallbackList.
- callbacks = CallbackList(model, callbacks)
-
- # Create new experiment.
- experiment_dir = create_experiment_dir(model, experiment_config)
+def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList:
+ """Configure a callback list for trainer."""
+ train_args = experiment_config.get("train_args", {})
- # Create log and model directories.
- log_dir = experiment_dir / "log"
- model_dir = experiment_dir / "model"
+ if "Checkpoint" in experiment_config["callback_args"]:
+ experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir
- # Set the model dir to be able to save checkpoints.
- model.model_dir = model_dir
+ # Callbacks
+ callback_modules = importlib.import_module("training.trainer.callbacks")
+ callbacks = [
+ getattr(callback_modules, callback)(
+ **check_args(experiment_config["callback_args"][callback], train_args)
+ )
+ for callback in experiment_config["callbacks"]
+ ]
- # Get checkpoint path.
- checkpoint_path = model_dir / "last.pt"
- if not checkpoint_path.exists():
- checkpoint_path = None
+ return callbacks
- # Make sure the log directory exists.
- log_dir.mkdir(parents=True, exist_ok=True)
+def configure_logger(experiment_config: Dict, log_dir: Path) -> None:
+ """Configure the loguru logger for output to terminal and disk."""
# Have to remove default logger to get tqdm to work properly.
logger.remove()
@@ -176,13 +190,50 @@ def run_experiment(
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
)
- if "cuda" in device:
- gpu_index = re.sub("[^0-9]+", "", device)
- logger.info(
- f"Running experiment with config {experiment_config} on GPU {gpu_index}"
- )
- else:
- logger.info(f"Running experiment with config {experiment_config} on CPU")
+
+def save_config(experiment_dir: Path, experiment_config: Dict) -> None:
+ """Copy config to experiment directory."""
+ config_path = experiment_dir / "config.yml"
+ with open(str(config_path), "w") as f:
+ yaml.dump(experiment_config, f)
+
+
+def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None:
+ """If checkpoint exists, load model weights and optimizers from checkpoint."""
+ # Get checkpoint path.
+ checkpoint_path = model_dir / "last.pt"
+ if checkpoint_path.exists():
+ logger.info("Loading and resuming training from last checkpoint.")
+ model.load_checkpoint(checkpoint_path)
+
+
+def run_experiment(
+ experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
+) -> None:
+ """Runs an experiment."""
+ logger.info(f"Experiment config: {json.dumps(experiment_config)}")
+
+ # Create new experiment.
+ experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)
+
+ # Make sure the log/model directory exists.
+ log_dir.mkdir(parents=True, exist_ok=True)
+ model_dir.mkdir(parents=True, exist_ok=True)
+
+ # Load the modules and model arguments.
+ model_class_, model_args = load_modules_and_arguments(experiment_config)
+
+ # Initializes the model with experiment config.
+ model = model_class_(**model_args, device=device)
+
+ callbacks = configure_callbacks(experiment_config, model_dir)
+
+ # Setup logger.
+ configure_logger(experiment_config, log_dir)
+
+ # Load from checkpoint if resuming an experiment.
+ if experiment_config["resume_experiment"] is not None:
+ load_from_checkpoint(model, log_dir, model_dir)
logger.info(f"The class mapping is {model.mapping}")
@@ -193,9 +244,6 @@ def run_experiment(
# Lets W&B save the model and track the gradients and optional parameters.
wandb.watch(model.network)
- # PÅ•ints a summary of the network in terminal.
- model.summary()
-
experiment_config["train_args"] = {
**DEFAULT_TRAIN_ARGS,
**experiment_config.get("train_args", {}),
@@ -208,41 +256,41 @@ def run_experiment(
experiment_config["device"] = device
# Save the config used in the experiment folder.
- config_path = experiment_dir / "config.yml"
- with open(str(config_path), "w") as f:
- yaml.dump(experiment_config, f)
+ save_config(experiment_dir, experiment_config)
- # Train the model.
+ # Load trainer.
trainer = Trainer(
- model=model,
- model_dir=model_dir,
- train_args=experiment_config["train_args"],
- callbacks=callbacks,
- checkpoint_path=checkpoint_path,
+ max_epochs=experiment_config["train_args"]["max_epochs"], callbacks=callbacks,
)
- trainer.fit()
+ # Train the model.
+ trainer.fit(model)
- logger.info("Loading checkpoint with the best weights.")
- model.load_checkpoint(model_dir / "best.pt")
+ # Run inference over test set.
+ if experiment_config["test"]:
+ logger.info("Loading checkpoint with the best weights.")
+ model.load_from_checkpoint(model_dir / "best.pt")
- score = trainer.validate()
+ logger.info("Running inference on test set.")
+ score = trainer.test(model)
- logger.info(f"Validation set evaluation: {score}")
+ logger.info(f"Test set evaluation: {score}")
- if use_wandb:
- wandb.log({"validation_metric": score["val_accuracy"]})
+ if use_wandb:
+ wandb.log(
+ {
+ experiment_config["test_metric"]: score[
+ experiment_config["test_metric"]
+ ]
+ }
+ )
if save_weights:
model.save_weights(model_dir)
@click.command()
-@click.option(
- "--experiment_config",
- type=str,
- help='Experiment JSON, e.g. \'{"dataloader": "EmnistDataLoader", "model": "CharacterModel", "network": "mlp"}\'',
-)
+@click.argument("experiment_config",)
@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.")
@click.option(
"--save",
diff --git a/src/training/run_sweep.py b/src/training/run_sweep.py
index 5c5322a..a578592 100644
--- a/src/training/run_sweep.py
+++ b/src/training/run_sweep.py
@@ -2,7 +2,91 @@
from ast import literal_eval
import json
import os
+from pathlib import Path
import signal
import subprocess # nosec
import sys
-from typing import Tuple
+from typing import Dict, List, Tuple
+
+import click
+import yaml
+
+EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
+
+
+def load_config() -> Dict:
+ """Load base hyperparameter config."""
+ with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f:
+ default_config = yaml.safe_load(f)
+ return default_config
+
+
+def args_to_json(
+ default_config: dict, preserve_args: tuple = ("gpu", "save")
+) -> Tuple[dict, list]:
+ """Convert command line arguments to nested config values.
+
+ i.e. run_sweep.py --dataset_args.foo=1.7
+ {
+ "dataset_args": {
+ "foo": 1.7
+ }
+ }
+
+ Args:
+ default_config (dict): The base config used for every experiment.
+ preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save").
+
+ Returns:
+ Tuple[dict, list]: Tuple of config dictionary and list of arguments.
+
+ """
+
+ args = []
+ config = default_config.copy()
+ key, val = None, None
+ for arg in sys.argv[1:]:
+ if "=" in arg:
+ key, val = arg.split("=")
+ elif key:
+ val = arg
+ else:
+ key = arg
+ if key and val:
+ parsed_key = key.lstrip("-").split(".")
+ if parsed_key[0] in preserve_args:
+ args.append("--{}={}".format(parsed_key[0], val))
+ else:
+ nested = config
+ for level in parsed_key[:-1]:
+ nested[level] = config.get(level, {})
+ nested = nested[level]
+ try:
+ # Convert numerics to floats / ints
+ val = literal_eval(val)
+ except ValueError:
+ pass
+ nested[parsed_key[-1]] = val
+ key, val = None, None
+ return config, args
+
+
+def main() -> None:
+ """Runs a W&B sweep."""
+ default_config = load_config()
+ config, args = args_to_json(default_config)
+ env = {
+ k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS")
+ }
+ # pylint: disable=subprocess-popen-preexec-fn
+ run = subprocess.Popen(
+ ["python", "training/run_experiment.py", *args, json.dumps(config)],
+ env=env,
+ preexec_fn=os.setsid,
+ ) # nosec
+ signal.signal(signal.SIGTERM, lambda *args: run.terminate())
+ run.wait()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/training/sweep_emnist.yml b/src/training/sweep_emnist.yml
new file mode 100644
index 0000000..48d7261
--- /dev/null
+++ b/src/training/sweep_emnist.yml
@@ -0,0 +1,26 @@
+program: training/run_sweep.py
+method: bayes
+metric:
+ name: val_loss
+ goal: minimize
+parameters:
+ dataset:
+ value: EmnistDataset
+ model:
+ value: CharacterModel
+ network:
+ value: MLP
+ network_args.hidden_size:
+ values: [128, 256]
+ network_args.dropout_rate:
+ values: [0.2, 0.4]
+ network_args.num_layers:
+ values: [3, 6]
+ optimizer_args.lr:
+ values: [1.e-1, 1.e-5]
+ lr_scheduler_args.max_lr:
+ values: [1.0e-1, 1.0e-5]
+ train_args.batch_size:
+ values: [64, 128]
+ train_args.epochs:
+ value: 5
diff --git a/src/training/sweep_emnist_resnet.yml b/src/training/sweep_emnist_resnet.yml
new file mode 100644
index 0000000..19a3040
--- /dev/null
+++ b/src/training/sweep_emnist_resnet.yml
@@ -0,0 +1,50 @@
+program: training/run_sweep.py
+method: bayes
+metric:
+ name: val_accuracy
+ goal: maximize
+parameters:
+ dataset:
+ value: EmnistDataset
+ model:
+ value: CharacterModel
+ network:
+ value: ResidualNetwork
+ network_args.block_sizes:
+ distribution: q_uniform
+ min: 16
+ max: 256
+ q: 8
+ network_args.depths:
+ distribution: int_uniform
+ min: 1
+ max: 3
+ network_args.levels:
+ distribution: int_uniform
+ min: 1
+ max: 2
+ network_args.activation:
+ distribution: categorical
+ values:
+ - gelu
+ - leaky_relu
+ - relu
+ - selu
+ optimizer_args.lr:
+ distribution: uniform
+ min: 1.e-5
+ max: 1.e-1
+ lr_scheduler_args.max_lr:
+ distribution: uniform
+ min: 1.e-5
+ max: 1.e-1
+ train_args.batch_size:
+ distribution: q_uniform
+ min: 32
+ max: 256
+ q: 8
+ train_args.epochs:
+ value: 5
+early_terminate:
+ type: hyperband
+ min_iter: 2
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index 5942276..c81e4bf 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -1,7 +1,16 @@
"""The callback modules used in the training script."""
-from .base import Callback, CallbackList, Checkpoint
+from .base import Callback, CallbackList
+from .checkpoint import Checkpoint
from .early_stopping import EarlyStopping
-from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
+from .lr_schedulers import (
+ CosineAnnealingLR,
+ CyclicLR,
+ MultiStepLR,
+ OneCycleLR,
+ ReduceLROnPlateau,
+ StepLR,
+ SWA,
+)
from .progress_bar import ProgressBar
from .wandb_callbacks import WandbCallback, WandbImageLogger
@@ -9,6 +18,7 @@ __all__ = [
"Callback",
"CallbackList",
"Checkpoint",
+ "CosineAnnealingLR",
"EarlyStopping",
"WandbCallback",
"WandbImageLogger",
@@ -18,4 +28,5 @@ __all__ = [
"ProgressBar",
"ReduceLROnPlateau",
"StepLR",
+ "SWA",
]
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index 8df94f3..8c7b085 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -168,81 +168,3 @@ class CallbackList:
def __iter__(self) -> iter:
"""Iter function for callback list."""
return iter(self._callbacks)
-
-
-class Checkpoint(Callback):
- """Saving model parameters at the end of each epoch."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0
- ) -> None:
- """Monitors a quantity that will allow us to determine the best model weights.
-
- Args:
- monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
- mode (str): Description of parameter `mode`. Defaults to "auto".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
-
- """
- super().__init__()
- self.monitor = monitor
- self.mode = mode
- self.min_delta = torch.tensor(min_delta)
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Saves a checkpoint for the network parameters.
-
- Args:
- epoch (int): The current epoch.
- logs (Dict): The log containing the monitored metrics.
-
- """
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- is_best = True
- else:
- is_best = False
-
- self.model.save_checkpoint(is_best, epoch, self.monitor)
-
- def get_monitor_value(self, logs: Dict) -> Union[float, None]:
- """Extracts the monitored value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
- + f"metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return monitor_value
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..6fe06d3
--- /dev/null
+++ b/src/training/trainer/callbacks/checkpoint.py
@@ -0,0 +1,95 @@
+"""Callback checkpoint for training models."""
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class Checkpoint(Callback):
+ """Saving model parameters at the end of each epoch."""
+
+ mode_dict = {
+ "min": torch.lt,
+ "max": torch.gt,
+ }
+
+ def __init__(
+ self,
+ checkpoint_path: Path,
+ monitor: str = "accuracy",
+ mode: str = "auto",
+ min_delta: float = 0.0,
+ ) -> None:
+ """Monitors a quantity that will allow us to determine the best model weights.
+
+ Args:
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
+ monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+ mode (str): Description of parameter `mode`. Defaults to "auto".
+ min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+
+ """
+ super().__init__()
+ self.checkpoint_path = checkpoint_path
+ self.monitor = monitor
+ self.mode = mode
+ self.min_delta = torch.tensor(min_delta)
+
+ if mode not in ["auto", "min", "max"]:
+ logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
+
+ self.mode = "auto"
+
+ if self.mode == "auto":
+ if "accuracy" in self.monitor:
+ self.mode = "max"
+ else:
+ self.mode = "min"
+ logger.debug(
+ f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+ )
+
+ torch_inf = torch.tensor(np.inf)
+ self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+ self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+ @property
+ def monitor_op(self) -> float:
+ """Returns the comparison method."""
+ return self.mode_dict[self.mode]
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Saves a checkpoint for the network parameters.
+
+ Args:
+ epoch (int): The current epoch.
+ logs (Dict): The log containing the monitored metrics.
+
+ """
+ current = self.get_monitor_value(logs)
+ if current is None:
+ return
+ if self.monitor_op(current - self.min_delta, self.best_score):
+ self.best_score = current
+ is_best = True
+ else:
+ is_best = False
+
+ self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
+
+ def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+ """Extracts the monitored value."""
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logger.warning(
+ f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
+ + f" metrics are: {','.join(list(logs.keys()))}"
+ )
+ return None
+ return monitor_value
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index ba2226a..bb41d2d 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -1,6 +1,7 @@
"""Callbacks for learning rate schedulers."""
from typing import Callable, Dict, List, Optional, Type
+from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback
from text_recognizer.models import Model
@@ -95,3 +96,54 @@ class OneCycleLR(Callback):
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
+
+
+class CosineAnnealingLR(Callback):
+ """Callback for Cosine Annealing."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every epoch."""
+ self.lr_scheduler.step()
+
+
+class SWA(Callback):
+ """Stochastic Weight Averaging callback."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.swa_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.swa_start = self.model.swa_start
+ self.swa_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if epoch > self.swa_start:
+ self.model.swa_network.update_parameters(self.model.network)
+ self.swa_scheduler.step()
+ else:
+ self.lr_scheduler.step()
+
+ def on_fit_end(self) -> None:
+ """Update batch norm statistics for the swa model at the end of training."""
+ if self.model.swa_network:
+ update_bn(
+ self.model.val_dataloader(),
+ self.model.swa_network,
+ device=self.model.device,
+ )
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
index 1970747..7829fa0 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -18,11 +18,11 @@ class ProgressBar(Callback):
def _configure_progress_bar(self) -> None:
"""Configures the tqdm progress bar with custom bar format."""
self.progress_bar = tqdm(
- total=len(self.model.data_loaders["train"]),
- leave=True,
- unit="step",
+ total=len(self.model.train_dataloader()),
+ leave=False,
+ unit="steps",
mininterval=self.log_batch_frequency,
- bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
)
def _key_abbreviations(self, logs: Dict) -> Dict:
@@ -34,13 +34,16 @@ class ProgressBar(Callback):
return {rename(key): value for key, value in logs.items()}
- def on_fit_begin(self) -> None:
- """Creates a tqdm progress bar."""
- self._configure_progress_bar()
+ # def on_fit_begin(self) -> None:
+ # """Creates a tqdm progress bar."""
+ # self._configure_progress_bar()
def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
"""Updates the description with the current epoch."""
- self.progress_bar.reset()
+ if epoch == 1:
+ self._configure_progress_bar()
+ else:
+ self.progress_bar.reset()
self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index e44c745..6643a44 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -2,7 +2,8 @@
from typing import Callable, Dict, List, Optional, Type
import numpy as np
-from torchvision.transforms import Compose, ToTensor
+import torch
+from torchvision.transforms import ToTensor
from training.trainer.callbacks import Callback
import wandb
@@ -50,43 +51,48 @@ class WandbImageLogger(Callback):
self,
example_indices: Optional[List] = None,
num_examples: int = 4,
- transfroms: Optional[Callable] = None,
+ use_transpose: Optional[bool] = False,
) -> None:
"""Initializes the WandbImageLogger with the model to train.
Args:
example_indices (Optional[List]): Indices for validation images. Defaults to None.
num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
- transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to
- None.
+ use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False.
"""
super().__init__()
self.example_indices = example_indices
self.num_examples = num_examples
- self.transfroms = transfroms
- if self.transfroms is None:
- self.transforms = Compose([Transpose()])
+ self.transpose = Transpose() if use_transpose else None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
self.model = model
- data_loader = self.model.data_loaders["val"]
if self.example_indices is None:
self.example_indices = np.random.randint(
- 0, len(data_loader.dataset.data), self.num_examples
+ 0, len(self.model.val_dataset), self.num_examples
)
- self.val_images = data_loader.dataset.data[self.example_indices]
- self.val_targets = data_loader.dataset.targets[self.example_indices].numpy()
+ self.val_images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.val_targets = self.val_targets.tolist()
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
"""Get network predictions on validation images."""
images = []
for i, image in enumerate(self.val_images):
- image = self.transforms(image)
+ image = self.transpose(image) if self.transpose is not None else image
pred, conf = self.model.predict_on_image(image)
- ground_truth = self.model.mapper(int(self.val_targets[i]))
+ if isinstance(self.val_targets[i], list):
+ ground_truth = "".join(
+ [
+ self.model.mapper(int(target_index))
+ for target_index in self.val_targets[i]
+ ]
+ ).rstrip("_")
+ else:
+ ground_truth = self.val_targets[i]
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index a75ae8f..b240157 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -8,8 +8,9 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
+from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback, CallbackList
-from training.trainer.util import RunningAverage
+from training.trainer.util import log_val_metric, RunningAverage
import wandb
from text_recognizer.models import Model
@@ -24,37 +25,55 @@ torch.cuda.manual_seed(4711)
class Trainer:
"""Trainer for training PyTorch models."""
- def __init__(
- self,
- model: Type[Model],
- model_dir: Path,
- train_args: Dict,
- callbacks: CallbackList,
- checkpoint_path: Optional[Path] = None,
- ) -> None:
+ # TODO: proper add teardown?
+
+ def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None:
"""Initialization of the Trainer.
Args:
- model (Type[Model]): A model object.
- model_dir (Path): Path to the model directory.
- train_args (Dict): The training arguments.
+ max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
- checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
"""
- self.model = model
- self.model_dir = model_dir
- self.checkpoint_path = checkpoint_path
+ # Training arguments.
self.start_epoch = 1
- self.epochs = train_args["epochs"]
+ self.max_epochs = max_epochs
self.callbacks = callbacks
- if self.checkpoint_path is not None:
- self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
+ # Flag for setting callbacks.
+ self.callbacks_configured = False
+
+ # Model placeholders
+ self.model = None
+
+ def _configure_callbacks(self) -> None:
+ if not self.callbacks_configured:
+ # Instantiate a CallbackList.
+ self.callbacks = CallbackList(self.model, self.callbacks)
+
+ def compute_metrics(
+ self,
+ output: Tensor,
+ targets: Tensor,
+ loss: Tensor,
+ loss_avg: Type[RunningAverage],
+ ) -> Dict:
+ """Computes metrics for output and target pairs."""
+ # Compute metrics.
+ loss = loss.detach().float().item()
+ loss_avg.update(loss)
+ output = output.detach()
+ targets = targets.detach()
+ if self.model.metrics is not None:
+ metrics = {
+ metric: self.model.metrics[metric](output, targets)
+ for metric in self.model.metrics
+ }
+ else:
+ metrics = {}
+ metrics["loss"] = loss
- # Parse the name of the experiment.
- experiment_dir = str(self.model_dir.parents[1]).split("/")
- self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1]
+ return metrics
def training_step(
self,
@@ -75,11 +94,12 @@ class Trainer:
output = self.model.network(data)
# Compute the loss.
- loss = self.model.criterion(output, targets)
+ loss = self.model.loss_fn(output, targets)
# Backward pass.
# Clear the previous gradients.
- self.model.optimizer.zero_grad()
+ for p in self.model.network.parameters():
+ p.grad = None
# Compute the gradients.
loss.backward()
@@ -87,15 +107,8 @@ class Trainer:
# Perform updates using calculated gradients.
self.model.optimizer.step()
- # Compute metrics.
- loss_avg.update(loss.item())
- output = output.data.cpu()
- targets = targets.data.cpu()
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
- metrics["loss"] = loss_avg()
+ metrics = self.compute_metrics(output, targets, loss, loss_avg)
+
return metrics
def train(self) -> None:
@@ -106,9 +119,7 @@ class Trainer:
# Running average for the loss.
loss_avg = RunningAverage()
- data_loader = self.model.data_loaders["train"]
-
- for batch, samples in enumerate(data_loader):
+ for batch, samples in enumerate(self.model.train_dataloader()):
self.callbacks.on_train_batch_begin(batch)
metrics = self.training_step(batch, samples, loss_avg)
self.callbacks.on_train_batch_end(batch, logs=metrics)
@@ -119,6 +130,7 @@ class Trainer:
batch: int,
samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
+ use_swa: bool = False,
) -> Dict:
"""Performs the validation step."""
# Pass the tensor to the device for computation.
@@ -130,44 +142,32 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.network(data)
+ # Use SWA if available and using test dataset.
+ if use_swa and self.model.swa_network is None:
+ output = self.model.swa_network(data)
+ else:
+ output = self.model.network(data)
# Compute the loss.
- loss = self.model.criterion(output, targets)
+ loss = self.model.loss_fn(output, targets)
# Compute metrics.
- loss_avg.update(loss.item())
- output = output.data.cpu()
- targets = targets.data.cpu()
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
- metrics["loss"] = loss.item()
+ metrics = self.compute_metrics(output, targets, loss, loss_avg)
return metrics
- def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None:
- log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
- logger.debug(
- log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
-
- def validate(self, epoch: Optional[int] = None) -> Dict:
+ def validate(self) -> Dict:
"""Runs the validation loop for one epoch."""
# Set model to eval mode.
self.model.eval()
# Running average for the loss.
- data_loader = self.model.data_loaders["val"]
-
- # Running average for the loss.
loss_avg = RunningAverage()
# Summary for the current eval loop.
summary = []
- for batch, samples in enumerate(data_loader):
+ for batch, samples in enumerate(self.model.val_dataloader()):
self.callbacks.on_validation_batch_begin(batch)
metrics = self.validation_step(batch, samples, loss_avg)
self.callbacks.on_validation_batch_end(batch, logs=metrics)
@@ -178,14 +178,19 @@ class Trainer:
"val_" + metric: np.mean([x[metric] for x in summary])
for metric in summary[0]
}
- self._log_val_metric(metrics_mean, epoch)
return metrics_mean
- def fit(self) -> None:
+ def fit(self, model: Type[Model]) -> None:
"""Runs the training and evaluation loop."""
- logger.debug(f"Running an experiment called {self.experiment_name}.")
+ # Sets model, loads the data, criterion, and optimizers.
+ self.model = model
+ self.model.prepare_data()
+ self.model.configure_model()
+
+ # Configure callbacks.
+ self._configure_callbacks()
# Set start time.
t_start = time.time()
@@ -193,14 +198,15 @@ class Trainer:
self.callbacks.on_fit_begin()
# Run the training loop.
- for epoch in range(self.start_epoch, self.epochs + 1):
+ for epoch in range(self.start_epoch, self.max_epochs + 1):
self.callbacks.on_epoch_begin(epoch)
# Perform one training pass over the training set.
self.train()
# Evaluate the model on the validation set.
- val_metrics = self.validate(epoch)
+ val_metrics = self.validate()
+ log_val_metric(val_metrics, epoch)
self.callbacks.on_epoch_end(epoch, logs=val_metrics)
@@ -214,3 +220,43 @@ class Trainer:
self.callbacks.on_fit_end()
logger.info(f"Training took {t_training:.2f} s.")
+
+ # "Teardown".
+ self.model = None
+
+ def test(self, model: Type[Model]) -> Dict:
+ """Run inference on test data."""
+
+ # Sets model, loads the data, criterion, and optimizers.
+ self.model = model
+ self.model.prepare_data()
+ self.model.configure_model()
+
+ # Configure callbacks.
+ self._configure_callbacks()
+
+ self.model.eval()
+
+ # Check if SWA network is available.
+ use_swa = True if self.model.swa_network is not None else False
+
+ # Running average for the loss.
+ loss_avg = RunningAverage()
+
+ # Summary for the current test loop.
+ summary = []
+
+ for batch, samples in enumerate(self.model.test_dataloader()):
+ metrics = self.validation_step(batch, samples, loss_avg, use_swa)
+ summary.append(metrics)
+
+ # Compute mean of all test metrics.
+ metrics_mean = {
+ "test_" + metric: np.mean([x[metric] for x in summary])
+ for metric in summary[0]
+ }
+
+ # "Teardown".
+ self.model = None
+
+ return metrics_mean
diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py
index 132b2dc..7cf1b45 100644
--- a/src/training/trainer/util.py
+++ b/src/training/trainer/util.py
@@ -1,4 +1,13 @@
"""Utility functions for training neural networks."""
+from typing import Dict, Optional
+
+from loguru import logger
+
+
+def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+ """Logging of val metrics to file/terminal."""
+ log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+ logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()))
class RunningAverage: