diff options
Diffstat (limited to 'src/training')
| -rw-r--r-- | src/training/experiments/sample_experiment.yml | 127 | ||||
| -rw-r--r-- | src/training/prepare_experiments.py | 4 | ||||
| -rw-r--r-- | src/training/run_experiment.py | 238 | ||||
| -rw-r--r-- | src/training/run_sweep.py | 86 | ||||
| -rw-r--r-- | src/training/sweep_emnist.yml | 26 | ||||
| -rw-r--r-- | src/training/sweep_emnist_resnet.yml | 50 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/__init__.py | 15 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/base.py | 78 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 95 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 52 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 19 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 32 | ||||
| -rw-r--r-- | src/training/trainer/train.py | 170 | ||||
| -rw-r--r-- | src/training/trainer/util.py | 9 | 
14 files changed, 686 insertions, 315 deletions
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index b00bd5a..17e220e 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -1,17 +1,20 @@  experiment_group: Sample Experiments  experiments: -    - dataset: EmnistDataset -      dataset_args: -        sample_to_balance: true -        subsample_fraction: null -        transform: null -        target_transform: null -        seed: 4711 -      data_loader_args: -        splits: [train, val] -        shuffle: true -        num_workers: 8 -        cuda: true +    - train_args: +        batch_size: 256 +        max_epochs: 32 +      dataset: +        type: EmnistDataset +        args: +          sample_to_balance: true +          subsample_fraction: null +          transform: null +          target_transform: null +          seed: 4711 +        train_args: +          num_workers: 6 +          train_fraction: 0.8 +        model: CharacterModel        metrics: [accuracy]        # network: MLP @@ -19,65 +22,81 @@ experiments:        #   input_size: 784        #   hidden_size: 512        #   output_size: 80 -      #   num_layers: 3 -      #   dropout_rate: 0 +      #   num_layers: 5 +      #   dropout_rate: 0.2        #   activation_fn: SELU -      network: ResidualNetwork -      network_args: -        in_channels: 1 -        num_classes: 80 -        depths: [2, 1] -        block_sizes: [96, 32] +      network: +        type: ResidualNetwork +        args: +          in_channels: 1 +          num_classes: 80 +          depths: [2, 2] +          block_sizes: [64, 64] +          activation: leaky_relu +          stn: true +      # network: +      #   type: WideResidualNetwork +      #   args: +      #     in_channels: 1 +      #     num_classes: 80 +      #     depth: 10 +      #     num_layers: 3 +      #     width_factor: 4 +      #     dropout_rate: 0.2 +      #     activation: SELU        # network: LeNet        # network_args:        #   output_size: 62        #   activation_fn: GELU -      train_args: -        batch_size: 256 -        epochs: 32 -      criterion: CrossEntropyLoss -      criterion_args: -        weight: null -        ignore_index: -100 -        reduction: mean -      # optimizer: RMSprop -      # optimizer_args: -      #   lr: 1.e-3 -      #   alpha: 0.9 -      #   eps: 1.e-7 -      #   momentum: 0 -      #   weight_decay: 0 -      #   centered: false -      optimizer: AdamW -      optimizer_args: -        lr: 1.e-03 -        betas: [0.9, 0.999] -        eps: 1.e-08 -        # weight_decay: 5.e-4 -        amsgrad: false -      # lr_scheduler: null -      lr_scheduler: OneCycleLR -      lr_scheduler_args: -        max_lr: 1.e-03 -        epochs: 32 -        anneal_strategy: linear -      callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] +      criterion: +        type: CrossEntropyLoss +        args: +          weight: null +          ignore_index: -100 +          reduction: mean +      optimizer: +        type: AdamW +        args: +          lr: 1.e-02 +          betas: [0.9, 0.999] +          eps: 1.e-08 +          # weight_decay: 5.e-4 +          amsgrad: false +      # lr_scheduler: +      #   type: OneCycleLR +      #   args: +      #     max_lr: 1.e-03 +      #     epochs: null +      #     anneal_strategy: linear +      lr_scheduler: +        type: CosineAnnealingLR +        args: +          T_max: null +      swa_args: +        start: 2 +        lr: 5.e-2 +      callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping, SWA] # OneCycleLR]        callback_args:          Checkpoint:            monitor: val_accuracy          ProgressBar: -          epochs: 32 +          epochs: null            log_batch_frequency: 100          EarlyStopping:            monitor: val_loss            min_delta: 0.0 -          patience: 3 +          patience: 5            mode: min          WandbCallback:            log_batch_frequency: 10          WandbImageLogger:            num_examples: 4 -        OneCycleLR: +          use_transpose: true +        # OneCycleLR: +        #   null +        SWA:            null -      verbosity: 1 # 0, 1, 2 +      verbosity: 0 # 0, 1, 2        resume_experiment: null +      test: true +      test_metric: test_accuracy diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index 4c3f9ba..e00540c 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -9,14 +9,14 @@ import yaml  def run_experiments(experiments_filename: str) -> None:      """Run experiment from file.""" -    with open(experiments_filename) as f: +    with open(experiments_filename, "r") as f:          experiments_config = yaml.safe_load(f)      num_experiments = len(experiments_config["experiments"])      for index in range(num_experiments):          experiment_config = experiments_config["experiments"][index]          experiment_config["experiment_group"] = experiments_config["experiment_group"] -        cmd = f"python training/run_experiment.py --gpu=-1 --save --experiment_config='{json.dumps(experiment_config)}'" +        cmd = f"python training/run_experiment.py --gpu=-1 --save '{json.dumps(experiment_config)}'"          print(cmd) diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 8c063ff..4317d66 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,18 +6,19 @@ import json  import os  from pathlib import Path  import re -from typing import Callable, Dict, Tuple, Type +from typing import Callable, Dict, List, Tuple, Type  import click  from loguru import logger  import torch  from tqdm import tqdm  from training.gpu_manager import GPUManager -from training.trainer.callbacks import CallbackList +from training.trainer.callbacks import Callback, CallbackList  from training.trainer.train import Trainer  import wandb  import yaml +  from text_recognizer.models import Model @@ -37,10 +38,14 @@ def get_level(experiment_config: Dict) -> int:          return 10 -def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: +def create_experiment_dir(experiment_config: Dict) -> Path:      """Create new experiment."""      EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) -    experiment_dir = EXPERIMENTS_DIRNAME / model.__name__ +    experiment_dir = EXPERIMENTS_DIRNAME / ( +        f"{experiment_config['model']}_" +        + f"{experiment_config['dataset']['type']}_" +        + f"{experiment_config['network']['type']}" +    )      if experiment_config["resume_experiment"] is None:          experiment = datetime.now().strftime("%m%d_%H%M%S")          logger.debug(f"Creating a new experiment called {experiment}") @@ -54,70 +59,89 @@ def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:              experiment = experiment_config["resume_experiment"]              if not str(experiment_dir / experiment) in available_experiments:                  raise FileNotFoundError("Experiment does not exist.") -            logger.debug(f"Resuming the experiment {experiment}")      experiment_dir = experiment_dir / experiment -    return experiment_dir +    # Create log and model directories. +    log_dir = experiment_dir / "log" +    model_dir = experiment_dir / "model" + +    return experiment_dir, log_dir, model_dir -def check_args(args: Dict) -> Dict: + +def check_args(args: Dict, train_args: Dict) -> Dict:      """Checks that the arguments are not None.""" +    args = args or {} + +    # I just want to set total epochs in train args, instead of changing all parameter. +    if "epochs" in args and args["epochs"] is None: +        args["epochs"] = train_args["max_epochs"] + +    # For CosineAnnealingLR. +    if "T_max" in args and args["T_max"] is None: +        args["T_max"] = train_args["max_epochs"] +      return args or {}  def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:      """Loads all modules and arguments."""      # Import the data loader arguments. -    data_loader_args = experiment_config.get("data_loader_args", {})      train_args = experiment_config.get("train_args", {}) -    data_loader_args["batch_size"] = train_args["batch_size"] -    data_loader_args["dataset"] = experiment_config["dataset"] -    data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) + +    # Load the dataset module. +    dataset_args = experiment_config.get("dataset", {}) +    dataset_args["train_args"]["batch_size"] = train_args["batch_size"] +    datasets_module = importlib.import_module("text_recognizer.datasets") +    dataset_ = getattr(datasets_module, dataset_args["type"])      # Import the model module and model arguments.      models_module = importlib.import_module("text_recognizer.models")      model_class_ = getattr(models_module, experiment_config["model"])      # Import metrics. -    metric_fns_ = { -        metric: getattr(models_module, metric) -        for metric in experiment_config["metrics"] -    } +    metric_fns_ = ( +        { +            metric: getattr(models_module, metric) +            for metric in experiment_config["metrics"] +        } +        if experiment_config["metrics"] is not None +        else None +    )      # Import network module and arguments.      network_module = importlib.import_module("text_recognizer.networks") -    network_fn_ = getattr(network_module, experiment_config["network"]) -    network_args = experiment_config.get("network_args", {}) +    network_fn_ = getattr(network_module, experiment_config["network"]["type"]) +    network_args = experiment_config["network"].get("args", {})      # Criterion -    criterion_ = getattr(torch.nn, experiment_config["criterion"]) -    criterion_args = experiment_config.get("criterion_args", {}) +    criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) +    criterion_args = experiment_config["criterion"].get("args", {}) -    # Optimizer -    optimizer_ = getattr(torch.optim, experiment_config["optimizer"]) -    optimizer_args = experiment_config.get("optimizer_args", {}) - -    # Callbacks -    callback_modules = importlib.import_module("training.trainer.callbacks") -    callbacks = [ -        getattr(callback_modules, callback)( -            **check_args(experiment_config["callback_args"][callback]) -        ) -        for callback in experiment_config["callbacks"] -    ] +    # Optimizers +    optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) +    optimizer_args = experiment_config["optimizer"].get("args", {})      # Learning rate scheduler +    lr_scheduler_ = None +    lr_scheduler_args = None      if experiment_config["lr_scheduler"] is not None:          lr_scheduler_ = getattr( -            torch.optim.lr_scheduler, experiment_config["lr_scheduler"] +            torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] +        ) +        lr_scheduler_args = check_args( +            experiment_config["lr_scheduler"].get("args", {}), train_args          ) -        lr_scheduler_args = experiment_config.get("lr_scheduler_args", {}) + +    # SWA scheduler. +    if "swa_args" in experiment_config: +        swa_args = check_args(experiment_config.get("swa_args", {}), train_args)      else: -        lr_scheduler_ = None -        lr_scheduler_args = None +        swa_args = None      model_args = { -        "data_loader_args": data_loader_args, +        "dataset": dataset_, +        "dataset_args": dataset_args,          "metrics": metric_fns_,          "network_fn": network_fn_,          "network_args": network_args, @@ -127,43 +151,33 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]          "optimizer_args": optimizer_args,          "lr_scheduler": lr_scheduler_,          "lr_scheduler_args": lr_scheduler_args, +        "swa_args": swa_args,      } -    return model_class_, model_args, callbacks - +    return model_class_, model_args -def run_experiment( -    experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False -) -> None: -    """Runs an experiment.""" - -    # Load the modules and model arguments. -    model_class_, model_args, callbacks = load_modules_and_arguments(experiment_config) - -    # Initializes the model with experiment config. -    model = model_class_(**model_args, device=device) -    # Instantiate a CallbackList. -    callbacks = CallbackList(model, callbacks) - -    # Create new experiment. -    experiment_dir = create_experiment_dir(model, experiment_config) +def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList: +    """Configure a callback list for trainer.""" +    train_args = experiment_config.get("train_args", {}) -    # Create log and model directories. -    log_dir = experiment_dir / "log" -    model_dir = experiment_dir / "model" +    if "Checkpoint" in experiment_config["callback_args"]: +        experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir -    # Set the model dir to be able to save checkpoints. -    model.model_dir = model_dir +    # Callbacks +    callback_modules = importlib.import_module("training.trainer.callbacks") +    callbacks = [ +        getattr(callback_modules, callback)( +            **check_args(experiment_config["callback_args"][callback], train_args) +        ) +        for callback in experiment_config["callbacks"] +    ] -    # Get checkpoint path. -    checkpoint_path = model_dir / "last.pt" -    if not checkpoint_path.exists(): -        checkpoint_path = None +    return callbacks -    # Make sure the log directory exists. -    log_dir.mkdir(parents=True, exist_ok=True) +def configure_logger(experiment_config: Dict, log_dir: Path) -> None: +    """Configure the loguru logger for output to terminal and disk."""      # Have to remove default logger to get tqdm to work properly.      logger.remove() @@ -176,13 +190,50 @@ def run_experiment(          format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",      ) -    if "cuda" in device: -        gpu_index = re.sub("[^0-9]+", "", device) -        logger.info( -            f"Running experiment with config {experiment_config} on GPU {gpu_index}" -        ) -    else: -        logger.info(f"Running experiment with config {experiment_config} on CPU") + +def save_config(experiment_dir: Path, experiment_config: Dict) -> None: +    """Copy config to experiment directory.""" +    config_path = experiment_dir / "config.yml" +    with open(str(config_path), "w") as f: +        yaml.dump(experiment_config, f) + + +def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None: +    """If checkpoint exists, load model weights and optimizers from checkpoint.""" +    # Get checkpoint path. +    checkpoint_path = model_dir / "last.pt" +    if checkpoint_path.exists(): +        logger.info("Loading and resuming training from last checkpoint.") +        model.load_checkpoint(checkpoint_path) + + +def run_experiment( +    experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False +) -> None: +    """Runs an experiment.""" +    logger.info(f"Experiment config: {json.dumps(experiment_config)}") + +    # Create new experiment. +    experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config) + +    # Make sure the log/model directory exists. +    log_dir.mkdir(parents=True, exist_ok=True) +    model_dir.mkdir(parents=True, exist_ok=True) + +    # Load the modules and model arguments. +    model_class_, model_args = load_modules_and_arguments(experiment_config) + +    # Initializes the model with experiment config. +    model = model_class_(**model_args, device=device) + +    callbacks = configure_callbacks(experiment_config, model_dir) + +    # Setup logger. +    configure_logger(experiment_config, log_dir) + +    # Load from checkpoint if resuming an experiment. +    if experiment_config["resume_experiment"] is not None: +        load_from_checkpoint(model, log_dir, model_dir)      logger.info(f"The class mapping is {model.mapping}") @@ -193,9 +244,6 @@ def run_experiment(          # Lets W&B save the model and track the gradients and optional parameters.          wandb.watch(model.network) -    # Pŕints a summary of the network in terminal. -    model.summary() -      experiment_config["train_args"] = {          **DEFAULT_TRAIN_ARGS,          **experiment_config.get("train_args", {}), @@ -208,41 +256,41 @@ def run_experiment(      experiment_config["device"] = device      # Save the config used in the experiment folder. -    config_path = experiment_dir / "config.yml" -    with open(str(config_path), "w") as f: -        yaml.dump(experiment_config, f) +    save_config(experiment_dir, experiment_config) -    # Train the model. +    # Load trainer.      trainer = Trainer( -        model=model, -        model_dir=model_dir, -        train_args=experiment_config["train_args"], -        callbacks=callbacks, -        checkpoint_path=checkpoint_path, +        max_epochs=experiment_config["train_args"]["max_epochs"], callbacks=callbacks,      ) -    trainer.fit() +    # Train the model. +    trainer.fit(model) -    logger.info("Loading checkpoint with the best weights.") -    model.load_checkpoint(model_dir / "best.pt") +    # Run inference over test set. +    if experiment_config["test"]: +        logger.info("Loading checkpoint with the best weights.") +        model.load_from_checkpoint(model_dir / "best.pt") -    score = trainer.validate() +        logger.info("Running inference on test set.") +        score = trainer.test(model) -    logger.info(f"Validation set evaluation: {score}") +        logger.info(f"Test set evaluation: {score}") -    if use_wandb: -        wandb.log({"validation_metric": score["val_accuracy"]}) +        if use_wandb: +            wandb.log( +                { +                    experiment_config["test_metric"]: score[ +                        experiment_config["test_metric"] +                    ] +                } +            )      if save_weights:          model.save_weights(model_dir)  @click.command() -@click.option( -    "--experiment_config", -    type=str, -    help='Experiment JSON, e.g. \'{"dataloader": "EmnistDataLoader", "model": "CharacterModel", "network": "mlp"}\'', -) +@click.argument("experiment_config",)  @click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.")  @click.option(      "--save", diff --git a/src/training/run_sweep.py b/src/training/run_sweep.py index 5c5322a..a578592 100644 --- a/src/training/run_sweep.py +++ b/src/training/run_sweep.py @@ -2,7 +2,91 @@  from ast import literal_eval  import json  import os +from pathlib import Path  import signal  import subprocess  # nosec  import sys -from typing import Tuple +from typing import Dict, List, Tuple + +import click +import yaml + +EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" + + +def load_config() -> Dict: +    """Load base hyperparameter config.""" +    with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f: +        default_config = yaml.safe_load(f) +    return default_config + + +def args_to_json( +    default_config: dict, preserve_args: tuple = ("gpu", "save") +) -> Tuple[dict, list]: +    """Convert command line arguments to nested config values. + +    i.e. run_sweep.py --dataset_args.foo=1.7 +    { +        "dataset_args": { +            "foo": 1.7 +        } +    } + +    Args: +        default_config (dict): The base config used for every experiment. +        preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save"). + +    Returns: +        Tuple[dict, list]: Tuple of config dictionary and list of arguments. + +    """ + +    args = [] +    config = default_config.copy() +    key, val = None, None +    for arg in sys.argv[1:]: +        if "=" in arg: +            key, val = arg.split("=") +        elif key: +            val = arg +        else: +            key = arg +        if key and val: +            parsed_key = key.lstrip("-").split(".") +            if parsed_key[0] in preserve_args: +                args.append("--{}={}".format(parsed_key[0], val)) +            else: +                nested = config +                for level in parsed_key[:-1]: +                    nested[level] = config.get(level, {}) +                    nested = nested[level] +                try: +                    # Convert numerics to floats / ints +                    val = literal_eval(val) +                except ValueError: +                    pass +                nested[parsed_key[-1]] = val +            key, val = None, None +    return config, args + + +def main() -> None: +    """Runs a W&B sweep.""" +    default_config = load_config() +    config, args = args_to_json(default_config) +    env = { +        k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS") +    } +    # pylint: disable=subprocess-popen-preexec-fn +    run = subprocess.Popen( +        ["python", "training/run_experiment.py", *args, json.dumps(config)], +        env=env, +        preexec_fn=os.setsid, +    )  # nosec +    signal.signal(signal.SIGTERM, lambda *args: run.terminate()) +    run.wait() + + +if __name__ == "__main__": +    main() diff --git a/src/training/sweep_emnist.yml b/src/training/sweep_emnist.yml new file mode 100644 index 0000000..48d7261 --- /dev/null +++ b/src/training/sweep_emnist.yml @@ -0,0 +1,26 @@ +program: training/run_sweep.py +method: bayes +metric: +  name: val_loss +  goal: minimize +parameters: +  dataset: +    value: EmnistDataset +  model: +    value: CharacterModel +  network: +    value: MLP +  network_args.hidden_size: +    values: [128, 256] +  network_args.dropout_rate: +    values: [0.2, 0.4] +  network_args.num_layers: +    values: [3, 6] +  optimizer_args.lr: +    values: [1.e-1, 1.e-5] +  lr_scheduler_args.max_lr: +    values: [1.0e-1, 1.0e-5] +  train_args.batch_size: +    values: [64, 128] +  train_args.epochs: +    value: 5 diff --git a/src/training/sweep_emnist_resnet.yml b/src/training/sweep_emnist_resnet.yml new file mode 100644 index 0000000..19a3040 --- /dev/null +++ b/src/training/sweep_emnist_resnet.yml @@ -0,0 +1,50 @@ +program: training/run_sweep.py +method: bayes +metric: +  name: val_accuracy +  goal: maximize +parameters: +  dataset: +    value: EmnistDataset +  model: +    value: CharacterModel +  network: +    value: ResidualNetwork +  network_args.block_sizes: +    distribution: q_uniform +    min: 16 +    max: 256 +    q: 8 +  network_args.depths: +    distribution: int_uniform +    min: 1 +    max: 3 +  network_args.levels: +      distribution: int_uniform +      min: 1 +      max: 2 +  network_args.activation: +    distribution: categorical +    values: +      - gelu +      - leaky_relu +      - relu +      - selu +  optimizer_args.lr: +    distribution: uniform +    min: 1.e-5 +    max: 1.e-1 +  lr_scheduler_args.max_lr: +    distribution: uniform +    min: 1.e-5 +    max: 1.e-1 +  train_args.batch_size: +    distribution: q_uniform +    min: 32 +    max: 256 +    q: 8 +  train_args.epochs: +    value: 5 +early_terminate: +   type: hyperband +   min_iter: 2 diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index 5942276..c81e4bf 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -1,7 +1,16 @@  """The callback modules used in the training script.""" -from .base import Callback, CallbackList, Checkpoint +from .base import Callback, CallbackList +from .checkpoint import Checkpoint  from .early_stopping import EarlyStopping -from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .lr_schedulers import ( +    CosineAnnealingLR, +    CyclicLR, +    MultiStepLR, +    OneCycleLR, +    ReduceLROnPlateau, +    StepLR, +    SWA, +)  from .progress_bar import ProgressBar  from .wandb_callbacks import WandbCallback, WandbImageLogger @@ -9,6 +18,7 @@ __all__ = [      "Callback",      "CallbackList",      "Checkpoint", +    "CosineAnnealingLR",      "EarlyStopping",      "WandbCallback",      "WandbImageLogger", @@ -18,4 +28,5 @@ __all__ = [      "ProgressBar",      "ReduceLROnPlateau",      "StepLR", +    "SWA",  ] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8df94f3..8c7b085 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -168,81 +168,3 @@ class CallbackList:      def __iter__(self) -> iter:          """Iter function for callback list."""          return iter(self._callbacks) - - -class Checkpoint(Callback): -    """Saving model parameters at the end of each epoch.""" - -    mode_dict = { -        "min": torch.lt, -        "max": torch.gt, -    } - -    def __init__( -        self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 -    ) -> None: -        """Monitors a quantity that will allow us to determine the best model weights. - -        Args: -            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". -            mode (str): Description of parameter `mode`. Defaults to "auto". -            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - -        """ -        super().__init__() -        self.monitor = monitor -        self.mode = mode -        self.min_delta = torch.tensor(min_delta) - -        if mode not in ["auto", "min", "max"]: -            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - -            self.mode = "auto" - -        if self.mode == "auto": -            if "accuracy" in self.monitor: -                self.mode = "max" -            else: -                self.mode = "min" -            logger.debug( -                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." -            ) - -        torch_inf = torch.tensor(np.inf) -        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 -        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - -    @property -    def monitor_op(self) -> float: -        """Returns the comparison method.""" -        return self.mode_dict[self.mode] - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Saves a checkpoint for the network parameters. - -        Args: -            epoch (int): The current epoch. -            logs (Dict): The log containing the monitored metrics. - -        """ -        current = self.get_monitor_value(logs) -        if current is None: -            return -        if self.monitor_op(current - self.min_delta, self.best_score): -            self.best_score = current -            is_best = True -        else: -            is_best = False - -        self.model.save_checkpoint(is_best, epoch, self.monitor) - -    def get_monitor_value(self, logs: Dict) -> Union[float, None]: -        """Extracts the monitored value.""" -        monitor_value = logs.get(self.monitor) -        if monitor_value is None: -            logger.warning( -                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" -                + f"metrics are: {','.join(list(logs.keys()))}" -            ) -            return None -        return monitor_value diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py new file mode 100644 index 0000000..6fe06d3 --- /dev/null +++ b/src/training/trainer/callbacks/checkpoint.py @@ -0,0 +1,95 @@ +"""Callback checkpoint for training models.""" +from enum import Enum +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class Checkpoint(Callback): +    """Saving model parameters at the end of each epoch.""" + +    mode_dict = { +        "min": torch.lt, +        "max": torch.gt, +    } + +    def __init__( +        self, +        checkpoint_path: Path, +        monitor: str = "accuracy", +        mode: str = "auto", +        min_delta: float = 0.0, +    ) -> None: +        """Monitors a quantity that will allow us to determine the best model weights. + +        Args: +            checkpoint_path (Path): Path to the experiment with the checkpoint. +            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". +            mode (str): Description of parameter `mode`. Defaults to "auto". +            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + +        """ +        super().__init__() +        self.checkpoint_path = checkpoint_path +        self.monitor = monitor +        self.mode = mode +        self.min_delta = torch.tensor(min_delta) + +        if mode not in ["auto", "min", "max"]: +            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + +            self.mode = "auto" + +        if self.mode == "auto": +            if "accuracy" in self.monitor: +                self.mode = "max" +            else: +                self.mode = "min" +            logger.debug( +                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." +            ) + +        torch_inf = torch.tensor(np.inf) +        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 +        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + +    @property +    def monitor_op(self) -> float: +        """Returns the comparison method.""" +        return self.mode_dict[self.mode] + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Saves a checkpoint for the network parameters. + +        Args: +            epoch (int): The current epoch. +            logs (Dict): The log containing the monitored metrics. + +        """ +        current = self.get_monitor_value(logs) +        if current is None: +            return +        if self.monitor_op(current - self.min_delta, self.best_score): +            self.best_score = current +            is_best = True +        else: +            is_best = False + +        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) + +    def get_monitor_value(self, logs: Dict) -> Union[float, None]: +        """Extracts the monitored value.""" +        monitor_value = logs.get(self.monitor) +        if monitor_value is None: +            logger.warning( +                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" +                + f" metrics are: {','.join(list(logs.keys()))}" +            ) +            return None +        return monitor_value diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index ba2226a..bb41d2d 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -1,6 +1,7 @@  """Callbacks for learning rate schedulers."""  from typing import Callable, Dict, List, Optional, Type +from torch.optim.swa_utils import update_bn  from training.trainer.callbacks import Callback  from text_recognizer.models import Model @@ -95,3 +96,54 @@ class OneCycleLR(Callback):      def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:          """Takes a step at the end of every training batch."""          self.lr_scheduler.step() + + +class CosineAnnealingLR(Callback): +    """Callback for Cosine Annealing.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        self.lr_scheduler.step() + + +class SWA(Callback): +    """Stochastic Weight Averaging callback.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.swa_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.swa_start = self.model.swa_start +        self.swa_scheduler = self.model.lr_scheduler +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every training batch.""" +        if epoch > self.swa_start: +            self.model.swa_network.update_parameters(self.model.network) +            self.swa_scheduler.step() +        else: +            self.lr_scheduler.step() + +    def on_fit_end(self) -> None: +        """Update batch norm statistics for the swa model at the end of training.""" +        if self.model.swa_network: +            update_bn( +                self.model.val_dataloader(), +                self.model.swa_network, +                device=self.model.device, +            ) diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 1970747..7829fa0 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -18,11 +18,11 @@ class ProgressBar(Callback):      def _configure_progress_bar(self) -> None:          """Configures the tqdm progress bar with custom bar format."""          self.progress_bar = tqdm( -            total=len(self.model.data_loaders["train"]), -            leave=True, -            unit="step", +            total=len(self.model.train_dataloader()), +            leave=False, +            unit="steps",              mininterval=self.log_batch_frequency, -            bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", +            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",          )      def _key_abbreviations(self, logs: Dict) -> Dict: @@ -34,13 +34,16 @@ class ProgressBar(Callback):          return {rename(key): value for key, value in logs.items()} -    def on_fit_begin(self) -> None: -        """Creates a tqdm progress bar.""" -        self._configure_progress_bar() +    # def on_fit_begin(self) -> None: +    #     """Creates a tqdm progress bar.""" +    #     self._configure_progress_bar()      def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:          """Updates the description with the current epoch.""" -        self.progress_bar.reset() +        if epoch == 1: +            self._configure_progress_bar() +        else: +            self.progress_bar.reset()          self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")      def on_epoch_end(self, epoch: int, logs: Dict) -> None: diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index e44c745..6643a44 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -2,7 +2,8 @@  from typing import Callable, Dict, List, Optional, Type  import numpy as np -from torchvision.transforms import Compose, ToTensor +import torch +from torchvision.transforms import ToTensor  from training.trainer.callbacks import Callback  import wandb @@ -50,43 +51,48 @@ class WandbImageLogger(Callback):          self,          example_indices: Optional[List] = None,          num_examples: int = 4, -        transfroms: Optional[Callable] = None, +        use_transpose: Optional[bool] = False,      ) -> None:          """Initializes the WandbImageLogger with the model to train.          Args:              example_indices (Optional[List]): Indices for validation images. Defaults to None.              num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. -            transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to -                None. +            use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False.          """          super().__init__()          self.example_indices = example_indices          self.num_examples = num_examples -        self.transfroms = transfroms -        if self.transfroms is None: -            self.transforms = Compose([Transpose()]) +        self.transpose = Transpose() if use_transpose else None      def set_model(self, model: Type[Model]) -> None:          """Sets the model and extracts validation images from the dataset."""          self.model = model -        data_loader = self.model.data_loaders["val"]          if self.example_indices is None:              self.example_indices = np.random.randint( -                0, len(data_loader.dataset.data), self.num_examples +                0, len(self.model.val_dataset), self.num_examples              ) -        self.val_images = data_loader.dataset.data[self.example_indices] -        self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() +        self.val_images = self.model.val_dataset.dataset.data[self.example_indices] +        self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] +        self.val_targets = self.val_targets.tolist()      def on_epoch_end(self, epoch: int, logs: Dict) -> None:          """Get network predictions on validation images."""          images = []          for i, image in enumerate(self.val_images): -            image = self.transforms(image) +            image = self.transpose(image) if self.transpose is not None else image              pred, conf = self.model.predict_on_image(image) -            ground_truth = self.model.mapper(int(self.val_targets[i])) +            if isinstance(self.val_targets[i], list): +                ground_truth = "".join( +                    [ +                        self.model.mapper(int(target_index)) +                        for target_index in self.val_targets[i] +                    ] +                ).rstrip("_") +            else: +                ground_truth = self.val_targets[i]              caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"              images.append(wandb.Image(image, caption=caption)) diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index a75ae8f..b240157 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -8,8 +8,9 @@ from loguru import logger  import numpy as np  import torch  from torch import Tensor +from torch.optim.swa_utils import update_bn  from training.trainer.callbacks import Callback, CallbackList -from training.trainer.util import RunningAverage +from training.trainer.util import log_val_metric, RunningAverage  import wandb  from text_recognizer.models import Model @@ -24,37 +25,55 @@ torch.cuda.manual_seed(4711)  class Trainer:      """Trainer for training PyTorch models.""" -    def __init__( -        self, -        model: Type[Model], -        model_dir: Path, -        train_args: Dict, -        callbacks: CallbackList, -        checkpoint_path: Optional[Path] = None, -    ) -> None: +    # TODO: proper add teardown? + +    def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None:          """Initialization of the Trainer.          Args: -            model (Type[Model]): A model object. -            model_dir (Path): Path to the model directory. -            train_args (Dict): The training arguments. +            max_epochs (int): The maximum number of epochs in the training loop.              callbacks (CallbackList): List of callbacks to be called. -            checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.          """ -        self.model = model -        self.model_dir = model_dir -        self.checkpoint_path = checkpoint_path +        # Training arguments.          self.start_epoch = 1 -        self.epochs = train_args["epochs"] +        self.max_epochs = max_epochs          self.callbacks = callbacks -        if self.checkpoint_path is not None: -            self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) +        # Flag for setting callbacks. +        self.callbacks_configured = False + +        # Model placeholders +        self.model = None + +    def _configure_callbacks(self) -> None: +        if not self.callbacks_configured: +            # Instantiate a CallbackList. +            self.callbacks = CallbackList(self.model, self.callbacks) + +    def compute_metrics( +        self, +        output: Tensor, +        targets: Tensor, +        loss: Tensor, +        loss_avg: Type[RunningAverage], +    ) -> Dict: +        """Computes metrics for output and target pairs.""" +        # Compute metrics. +        loss = loss.detach().float().item() +        loss_avg.update(loss) +        output = output.detach() +        targets = targets.detach() +        if self.model.metrics is not None: +            metrics = { +                metric: self.model.metrics[metric](output, targets) +                for metric in self.model.metrics +            } +        else: +            metrics = {} +        metrics["loss"] = loss -        # Parse the name of the experiment. -        experiment_dir = str(self.model_dir.parents[1]).split("/") -        self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] +        return metrics      def training_step(          self, @@ -75,11 +94,12 @@ class Trainer:          output = self.model.network(data)          # Compute the loss. -        loss = self.model.criterion(output, targets) +        loss = self.model.loss_fn(output, targets)          # Backward pass.          # Clear the previous gradients. -        self.model.optimizer.zero_grad() +        for p in self.model.network.parameters(): +            p.grad = None          # Compute the gradients.          loss.backward() @@ -87,15 +107,8 @@ class Trainer:          # Perform updates using calculated gradients.          self.model.optimizer.step() -        # Compute metrics. -        loss_avg.update(loss.item()) -        output = output.data.cpu() -        targets = targets.data.cpu() -        metrics = { -            metric: self.model.metrics[metric](output, targets) -            for metric in self.model.metrics -        } -        metrics["loss"] = loss_avg() +        metrics = self.compute_metrics(output, targets, loss, loss_avg) +          return metrics      def train(self) -> None: @@ -106,9 +119,7 @@ class Trainer:          # Running average for the loss.          loss_avg = RunningAverage() -        data_loader = self.model.data_loaders["train"] - -        for batch, samples in enumerate(data_loader): +        for batch, samples in enumerate(self.model.train_dataloader()):              self.callbacks.on_train_batch_begin(batch)              metrics = self.training_step(batch, samples, loss_avg)              self.callbacks.on_train_batch_end(batch, logs=metrics) @@ -119,6 +130,7 @@ class Trainer:          batch: int,          samples: Tuple[Tensor, Tensor],          loss_avg: Type[RunningAverage], +        use_swa: bool = False,      ) -> Dict:          """Performs the validation step."""          # Pass the tensor to the device for computation. @@ -130,44 +142,32 @@ class Trainer:          # Forward pass.          # Get the network prediction. -        output = self.model.network(data) +        # Use SWA if available and using test dataset. +        if use_swa and self.model.swa_network is None: +            output = self.model.swa_network(data) +        else: +            output = self.model.network(data)          # Compute the loss. -        loss = self.model.criterion(output, targets) +        loss = self.model.loss_fn(output, targets)          # Compute metrics. -        loss_avg.update(loss.item()) -        output = output.data.cpu() -        targets = targets.data.cpu() -        metrics = { -            metric: self.model.metrics[metric](output, targets) -            for metric in self.model.metrics -        } -        metrics["loss"] = loss.item() +        metrics = self.compute_metrics(output, targets, loss, loss_avg)          return metrics -    def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None: -        log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") -        logger.debug( -            log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) -        ) - -    def validate(self, epoch: Optional[int] = None) -> Dict: +    def validate(self) -> Dict:          """Runs the validation loop for one epoch."""          # Set model to eval mode.          self.model.eval()          # Running average for the loss. -        data_loader = self.model.data_loaders["val"] - -        # Running average for the loss.          loss_avg = RunningAverage()          # Summary for the current eval loop.          summary = [] -        for batch, samples in enumerate(data_loader): +        for batch, samples in enumerate(self.model.val_dataloader()):              self.callbacks.on_validation_batch_begin(batch)              metrics = self.validation_step(batch, samples, loss_avg)              self.callbacks.on_validation_batch_end(batch, logs=metrics) @@ -178,14 +178,19 @@ class Trainer:              "val_" + metric: np.mean([x[metric] for x in summary])              for metric in summary[0]          } -        self._log_val_metric(metrics_mean, epoch)          return metrics_mean -    def fit(self) -> None: +    def fit(self, model: Type[Model]) -> None:          """Runs the training and evaluation loop.""" -        logger.debug(f"Running an experiment called {self.experiment_name}.") +        # Sets model, loads the data, criterion, and optimizers. +        self.model = model +        self.model.prepare_data() +        self.model.configure_model() + +        # Configure callbacks. +        self._configure_callbacks()          # Set start time.          t_start = time.time() @@ -193,14 +198,15 @@ class Trainer:          self.callbacks.on_fit_begin()          # Run the training loop. -        for epoch in range(self.start_epoch, self.epochs + 1): +        for epoch in range(self.start_epoch, self.max_epochs + 1):              self.callbacks.on_epoch_begin(epoch)              # Perform one training pass over the training set.              self.train()              # Evaluate the model on the validation set. -            val_metrics = self.validate(epoch) +            val_metrics = self.validate() +            log_val_metric(val_metrics, epoch)              self.callbacks.on_epoch_end(epoch, logs=val_metrics) @@ -214,3 +220,43 @@ class Trainer:          self.callbacks.on_fit_end()          logger.info(f"Training took {t_training:.2f} s.") + +        # "Teardown". +        self.model = None + +    def test(self, model: Type[Model]) -> Dict: +        """Run inference on test data.""" + +        # Sets model, loads the data, criterion, and optimizers. +        self.model = model +        self.model.prepare_data() +        self.model.configure_model() + +        # Configure callbacks. +        self._configure_callbacks() + +        self.model.eval() + +        # Check if SWA network is available. +        use_swa = True if self.model.swa_network is not None else False + +        # Running average for the loss. +        loss_avg = RunningAverage() + +        # Summary for the current test loop. +        summary = [] + +        for batch, samples in enumerate(self.model.test_dataloader()): +            metrics = self.validation_step(batch, samples, loss_avg, use_swa) +            summary.append(metrics) + +        # Compute mean of all test metrics. +        metrics_mean = { +            "test_" + metric: np.mean([x[metric] for x in summary]) +            for metric in summary[0] +        } + +        # "Teardown". +        self.model = None + +        return metrics_mean diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py index 132b2dc..7cf1b45 100644 --- a/src/training/trainer/util.py +++ b/src/training/trainer/util.py @@ -1,4 +1,13 @@  """Utility functions for training neural networks.""" +from typing import Dict, Optional + +from loguru import logger + + +def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None: +    """Logging of val metrics to file/terminal.""" +    log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") +    logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()))  class RunningAverage:  |