From dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 8 Nov 2020 14:54:44 +0100 Subject: new updates --- src/training/experiments/default_config_emnist.yml | 1 + src/training/experiments/embedding_experiment.yml | 64 ++++++++ src/training/experiments/line_ctc_experiment.yml | 91 ------------ src/training/experiments/sample_experiment.yml | 1 + src/training/prepare_experiments.py | 2 - src/training/run_experiment.py | 161 ++++++++++++++------- src/training/trainer/callbacks/base.py | 20 ++- src/training/trainer/callbacks/checkpoint.py | 6 +- src/training/trainer/callbacks/lr_schedulers.py | 5 +- src/training/trainer/callbacks/wandb_callbacks.py | 34 ++++- .../trainer/population_based_training/__init__.py | 1 - .../population_based_training.py | 1 - src/training/trainer/train.py | 42 +++++- 13 files changed, 266 insertions(+), 163 deletions(-) create mode 100644 src/training/experiments/embedding_experiment.yml delete mode 100644 src/training/experiments/line_ctc_experiment.yml delete mode 100644 src/training/trainer/population_based_training/__init__.py delete mode 100644 src/training/trainer/population_based_training/population_based_training.py (limited to 'src/training') diff --git a/src/training/experiments/default_config_emnist.yml b/src/training/experiments/default_config_emnist.yml index 12a0a9d..bf2ed0a 100644 --- a/src/training/experiments/default_config_emnist.yml +++ b/src/training/experiments/default_config_emnist.yml @@ -66,4 +66,5 @@ callback_args: null verbosity: 1 # 0, 1, 2 resume_experiment: null +train: true validation_metric: val_accuracy diff --git a/src/training/experiments/embedding_experiment.yml b/src/training/experiments/embedding_experiment.yml new file mode 100644 index 0000000..1e5f941 --- /dev/null +++ b/src/training/experiments/embedding_experiment.yml @@ -0,0 +1,64 @@ +experiment_group: Embedding Experiments +experiments: + - train_args: + transformer_model: false + batch_size: &batch_size 256 + max_epochs: &max_epochs 32 + input_shape: [[1, 28, 28]] + dataset: + type: EmnistDataset + args: + sample_to_balance: true + subsample_fraction: null + transform: null + target_transform: null + seed: 4711 + train_args: + num_workers: 8 + train_fraction: 0.85 + batch_size: *batch_size + model: CharacterModel + metrics: [] + network: + type: DenseNet + args: + growth_rate: 4 + block_config: [4, 4] + in_channels: 1 + base_channels: 24 + num_classes: 128 + bn_size: 4 + dropout_rate: 0.1 + classifier: true + activation: elu + criterion: + type: EmbeddingLoss + args: + margin: 0.2 + type_of_triplets: semihard + optimizer: + type: AdamW + args: + lr: 1.e-02 + betas: [0.9, 0.999] + eps: 1.e-08 + weight_decay: 5.e-4 + amsgrad: false + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: *max_epochs + callbacks: [Checkpoint, ProgressBar, WandbCallback] + callback_args: + Checkpoint: + monitor: val_loss + mode: min + ProgressBar: + epochs: *max_epochs + WandbCallback: + log_batch_frequency: 10 + verbosity: 1 # 0, 1, 2 + resume_experiment: null + train: true + test: true + test_metric: mean_average_precision_at_r diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml deleted file mode 100644 index 432d1cc..0000000 --- a/src/training/experiments/line_ctc_experiment.yml +++ /dev/null @@ -1,91 +0,0 @@ -experiment_group: Lines Experiments -experiments: - - train_args: - batch_size: 42 - max_epochs: &max_epochs 32 - dataset: - type: IamLinesDataset - args: - subsample_fraction: null - transform: null - target_transform: null - train_args: - num_workers: 8 - train_fraction: 0.85 - model: LineCTCModel - metrics: [cer, wer] - network: - type: LineRecurrentNetwork - args: - backbone: ResidualNetwork - backbone_args: - in_channels: 1 - num_classes: 64 # Embedding - depths: [2,2] - block_sizes: [32,64] - activation: selu - stn: false - # encoder: ResidualNetwork - # encoder_args: - # pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0917_203601/model/best.pt - # freeze: false - flatten: false - input_size: 64 - hidden_size: 64 - bidirectional: true - num_layers: 2 - num_classes: 80 - patch_size: [28, 18] - stride: [1, 4] - criterion: - type: CTCLoss - args: - blank: 79 - optimizer: - type: AdamW - args: - lr: 1.e-02 - betas: [0.9, 0.999] - eps: 1.e-08 - weight_decay: 5.e-4 - amsgrad: false - lr_scheduler: - type: OneCycleLR - args: - max_lr: 1.e-02 - epochs: *max_epochs - anneal_strategy: cos - pct_start: 0.475 - cycle_momentum: true - base_momentum: 0.85 - max_momentum: 0.9 - div_factor: 10 - final_div_factor: 10000 - interval: step - # lr_scheduler: - # type: CosineAnnealingLR - # args: - # T_max: *max_epochs - swa_args: - start: 24 - lr: 5.e-2 - callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger] # EarlyStopping] - callback_args: - Checkpoint: - monitor: val_loss - mode: min - ProgressBar: - epochs: *max_epochs - # EarlyStopping: - # monitor: val_loss - # min_delta: 0.0 - # patience: 10 - # mode: min - WandbCallback: - log_batch_frequency: 10 - WandbImageLogger: - num_examples: 6 - verbosity: 1 # 0, 1, 2 - resume_experiment: null - test: true - test_metric: test_cer diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 8664a15..a073a87 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -95,5 +95,6 @@ experiments: use_transpose: true verbosity: 0 # 0, 1, 2 resume_experiment: null + train: true test: true test_metric: test_accuracy diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index e00540c..6e20bcd 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -1,9 +1,7 @@ """Run a experiment from a config file.""" import json -from subprocess import run import click -from loguru import logger import yaml diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index a347d9f..0510d5c 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,12 +6,15 @@ import json import os from pathlib import Path import re -from typing import Callable, Dict, List, Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type +import warnings +import adabelief_pytorch import click from loguru import logger import numpy as np import torch +from torchsummary import summary from tqdm import tqdm from training.gpu_manager import GPUManager from training.trainer.callbacks import Callback, CallbackList @@ -21,26 +24,23 @@ import yaml from text_recognizer.models import Model -from text_recognizer.networks import losses - +from text_recognizer.networks import loss as custom_loss_module EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" -CUSTOM_LOSSES = ["EmbeddingLoss"] DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16} -def get_level(experiment_config: Dict) -> int: +def _get_level(verbose: int) -> int: """Sets the logger level.""" - if experiment_config["verbosity"] == 0: - return 40 - elif experiment_config["verbosity"] == 1: - return 20 - else: - return 10 + levels = {0: 40, 1: 20, 2: 10} + verbose = verbose if verbose <= 2 else 2 + return levels[verbose] -def create_experiment_dir(experiment_config: Dict) -> Path: +def _create_experiment_dir( + experiment_config: Dict, checkpoint: Optional[str] = None +) -> Path: """Create new experiment.""" EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment_dir = EXPERIMENTS_DIRNAME / ( @@ -48,19 +48,21 @@ def create_experiment_dir(experiment_config: Dict) -> Path: + f"{experiment_config['dataset']['type']}_" + f"{experiment_config['network']['type']}" ) - if experiment_config["resume_experiment"] is None: + + if checkpoint is None: experiment = datetime.now().strftime("%m%d_%H%M%S") logger.debug(f"Creating a new experiment called {experiment}") else: available_experiments = glob(str(experiment_dir) + "/*") available_experiments.sort() - if experiment_config["resume_experiment"] == "last": + if checkpoint == "last": experiment = available_experiments[-1] logger.debug(f"Resuming the latest experiment {experiment}") else: - experiment = experiment_config["resume_experiment"] + experiment = checkpoint if not str(experiment_dir / experiment) in available_experiments: raise FileNotFoundError("Experiment does not exist.") + logger.debug(f"Resuming the from experiment {checkpoint}") experiment_dir = experiment_dir / experiment @@ -71,14 +73,10 @@ def create_experiment_dir(experiment_config: Dict) -> Path: return experiment_dir, log_dir, model_dir -def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]: +def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]: """Loads all modules and arguments.""" - # Import the data loader arguments. - train_args = experiment_config.get("train_args", {}) - # Load the dataset module. dataset_args = experiment_config.get("dataset", {}) - dataset_args["train_args"]["batch_size"] = train_args["batch_size"] datasets_module = importlib.import_module("text_recognizer.datasets") dataset_ = getattr(datasets_module, dataset_args["type"]) @@ -102,21 +100,24 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] network_args = experiment_config["network"].get("args", {}) # Criterion - if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: - criterion_ = getattr(losses, experiment_config["criterion"]["type"]) - criterion_args = experiment_config["criterion"].get("args", {}) + if experiment_config["criterion"]["type"] in custom_loss_module.__all__: + criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"]) else: criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) - criterion_args = experiment_config["criterion"].get("args", {}) + criterion_args = experiment_config["criterion"].get("args", {}) or {} # Optimizers - optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) + if experiment_config["optimizer"]["type"] == "AdaBelief": + warnings.filterwarnings("ignore", category=UserWarning) + optimizer_ = getattr(adabelief_pytorch, experiment_config["optimizer"]["type"]) + else: + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) optimizer_args = experiment_config["optimizer"].get("args", {}) # Learning rate scheduler lr_scheduler_ = None lr_scheduler_args = None - if experiment_config["lr_scheduler"] is not None: + if "lr_scheduler" in experiment_config: lr_scheduler_ = getattr( torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] ) @@ -146,10 +147,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] return model_class_, model_args -def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList: +def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList: """Configure a callback list for trainer.""" if "Checkpoint" in experiment_config["callback_args"]: - experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir + experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str( + model_dir + ) # Initializes callbacks. callback_modules = importlib.import_module("training.trainer.callbacks") @@ -161,13 +164,13 @@ def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackLis return callbacks -def configure_logger(experiment_config: Dict, log_dir: Path) -> None: +def _configure_logger(log_dir: Path, verbose: int = 0) -> None: """Configure the loguru logger for output to terminal and disk.""" # Have to remove default logger to get tqdm to work properly. logger.remove() # Fetch verbosity level. - level = get_level(experiment_config) + level = _get_level(verbose) logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) logger.add( @@ -176,20 +179,29 @@ def configure_logger(experiment_config: Dict, log_dir: Path) -> None: ) -def save_config(experiment_dir: Path, experiment_config: Dict) -> None: +def _save_config(experiment_dir: Path, experiment_config: Dict) -> None: """Copy config to experiment directory.""" config_path = experiment_dir / "config.yml" with open(str(config_path), "w") as f: yaml.dump(experiment_config, f) -def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None: +def _load_from_checkpoint( + model: Type[Model], model_dir: Path, pretrained_weights: str = None, +) -> None: """If checkpoint exists, load model weights and optimizers from checkpoint.""" # Get checkpoint path. - checkpoint_path = model_dir / "last.pt" + if pretrained_weights is not None: + logger.info(f"Loading weights from {pretrained_weights}.") + checkpoint_path = ( + EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt" + ) + else: + logger.info(f"Loading weights from {model_dir}.") + checkpoint_path = model_dir / "last.pt" if checkpoint_path.exists(): - logger.info("Loading and resuming training from last checkpoint.") - model.load_checkpoint(checkpoint_path) + logger.info("Loading and resuming training from checkpoint.") + model.load_from_checkpoint(checkpoint_path) def evaluate_embedding(model: Type[Model]) -> Dict: @@ -217,38 +229,50 @@ def evaluate_embedding(model: Type[Model]) -> Dict: def run_experiment( - experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False + experiment_config: Dict, + save_weights: bool, + device: str, + use_wandb: bool, + train: bool, + test: bool, + verbose: int = 0, + checkpoint: Optional[str] = None, + pretrained_weights: Optional[str] = None, ) -> None: """Runs an experiment.""" logger.info(f"Experiment config: {json.dumps(experiment_config)}") # Create new experiment. - experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config) + experiment_dir, log_dir, model_dir = _create_experiment_dir( + experiment_config, checkpoint + ) # Make sure the log/model directory exists. log_dir.mkdir(parents=True, exist_ok=True) model_dir.mkdir(parents=True, exist_ok=True) # Load the modules and model arguments. - model_class_, model_args = load_modules_and_arguments(experiment_config) + model_class_, model_args = _load_modules_and_arguments(experiment_config) # Initializes the model with experiment config. model = model_class_(**model_args, device=device) - callbacks = configure_callbacks(experiment_config, model_dir) + callbacks = _configure_callbacks(experiment_config, model_dir) # Setup logger. - configure_logger(experiment_config, log_dir) + _configure_logger(log_dir, verbose) # Load from checkpoint if resuming an experiment. - if experiment_config["resume_experiment"] is not None: - load_from_checkpoint(model, log_dir, model_dir) + resume = False + if checkpoint is not None or pretrained_weights is not None: + resume = True + _load_from_checkpoint(model, model_dir, pretrained_weights) logger.info(f"The class mapping is {model.mapping}") # Initializes Weights & Biases if use_wandb: - wandb.init(project="text-recognizer", config=experiment_config) + wandb.init(project="text-recognizer", config=experiment_config, resume=resume) # Lets W&B save the model and track the gradients and optional parameters. wandb.watch(model.network) @@ -265,23 +289,30 @@ def run_experiment( experiment_config["device"] = device # Save the config used in the experiment folder. - save_config(experiment_dir, experiment_config) + _save_config(experiment_dir, experiment_config) + + # Prints a summary of the network in terminal. + model.summary(experiment_config["train_args"]["input_shape"]) # Load trainer. trainer = Trainer( - max_epochs=experiment_config["train_args"]["max_epochs"], callbacks=callbacks, + max_epochs=experiment_config["train_args"]["max_epochs"], + callbacks=callbacks, + transformer_model=experiment_config["train_args"]["transformer_model"], + max_norm=experiment_config["train_args"]["max_norm"], ) # Train the model. - trainer.fit(model) + if train: + trainer.fit(model) # Run inference over test set. - if experiment_config["test"]: + if test: logger.info("Loading checkpoint with the best weights.") model.load_from_checkpoint(model_dir / "best.pt") logger.info("Running inference on test set.") - if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: + if experiment_config["criterion"]["type"] == "EmbeddingLoss": logger.info("Evaluating embedding.") score = evaluate_embedding(model) else: @@ -313,7 +344,26 @@ def run_experiment( @click.option( "--nowandb", is_flag=False, help="If true, do not use wandb for this run." ) -def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: +@click.option("--test", is_flag=True, help="If true, test the model.") +@click.option("-v", "--verbose", count=True) +@click.option("--checkpoint", type=str, help="Path to the experiment.") +@click.option( + "--pretrained_weights", type=str, help="Path to pretrained model weights." +) +@click.option( + "--notrain", is_flag=False, is_eager=True, help="Do not train the model.", +) +def run_cli( + experiment_config: str, + gpu: int, + save: bool, + nowandb: bool, + notrain: bool, + test: bool, + verbose: int, + checkpoint: Optional[str] = None, + pretrained_weights: Optional[str] = None, +) -> None: """Run experiment.""" if gpu < 0: gpu_manager = GPUManager(True) @@ -322,7 +372,18 @@ def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None experiment_config = json.loads(experiment_config) os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" - run_experiment(experiment_config, save, device, use_wandb=not nowandb) + + run_experiment( + experiment_config, + save, + device, + use_wandb=not nowandb, + train=not notrain, + test=test, + verbose=verbose, + checkpoint=checkpoint, + pretrained_weights=pretrained_weights, + ) if __name__ == "__main__": diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8c7b085..500b642 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -62,6 +62,14 @@ class Callback: """Called at the end of an epoch.""" pass + def on_test_begin(self) -> None: + """Called at the beginning of test.""" + pass + + def on_test_end(self) -> None: + """Called at the end of test.""" + pass + class CallbackList: """Container for abstracting away callback calls.""" @@ -92,7 +100,7 @@ class CallbackList: def append(self, callback: Type[Callback]) -> None: """Append new callback to callback list.""" - self.callbacks.append(callback) + self._callbacks.append(callback) def on_fit_begin(self) -> None: """Called when fit begins.""" @@ -104,6 +112,16 @@ class CallbackList: for callback in self._callbacks: callback.on_fit_end() + def on_test_begin(self) -> None: + """Called when test begins.""" + for callback in self._callbacks: + callback.on_test_begin() + + def on_test_end(self) -> None: + """Called when test ends.""" + for callback in self._callbacks: + callback.on_test_end() + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" for callback in self._callbacks: diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py index 6fe06d3..a54e0a9 100644 --- a/src/training/trainer/callbacks/checkpoint.py +++ b/src/training/trainer/callbacks/checkpoint.py @@ -21,7 +21,7 @@ class Checkpoint(Callback): def __init__( self, - checkpoint_path: Path, + checkpoint_path: Union[str, Path], monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0, @@ -29,14 +29,14 @@ class Checkpoint(Callback): """Monitors a quantity that will allow us to determine the best model weights. Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. + checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. monitor (str): Name of the quantity to monitor. Defaults to "accuracy". mode (str): Description of parameter `mode`. Defaults to "auto". min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. """ super().__init__() - self.checkpoint_path = checkpoint_path + self.checkpoint_path = Path(checkpoint_path) self.monitor = monitor self.mode = mode self.min_delta = torch.tensor(min_delta) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index 907e292..630c434 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -22,7 +22,10 @@ class LRScheduler(Callback): def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" if self.interval == "epoch": - self.lr_scheduler.step() + if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: + self.lr_scheduler.step(logs["val_loss"]) + else: + self.lr_scheduler.step() def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index d2df4d7..1627f17 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -64,37 +64,55 @@ class WandbImageLogger(Callback): """ super().__init__() + self.caption = None self.example_indices = example_indices + self.test_sample_indices = None self.num_examples = num_examples self.transpose = Transpose() if use_transpose else None def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" self.model = model + self.caption = "Validation Examples" if self.example_indices is None: self.example_indices = np.random.randint( 0, len(self.model.val_dataset), self.num_examples ) - self.val_images = self.model.val_dataset.dataset.data[self.example_indices] - self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] - self.val_targets = self.val_targets.tolist() + self.images = self.model.val_dataset.dataset.data[self.example_indices] + self.targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.targets = self.targets.tolist() + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + self.targets = self.model.test_dataset.targets[self.test_sample_indices] + self.targets = self.targets.tolist() + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) def on_epoch_end(self, epoch: int, logs: Dict) -> None: """Get network predictions on validation images.""" images = [] - for i, image in enumerate(self.val_images): + for i, image in enumerate(self.images): image = self.transpose(image) if self.transpose is not None else image pred, conf = self.model.predict_on_image(image) - if isinstance(self.val_targets[i], list): + if isinstance(self.targets[i], list): ground_truth = "".join( [ self.model.mapper(int(target_index)) - for target_index in self.val_targets[i] + for target_index in self.targets[i] ] ).rstrip("_") else: - ground_truth = self.val_targets[i] + ground_truth = self.model.mapper(int(self.targets[i])) caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) - wandb.log({"examples": images}, commit=False) + wandb.log({f"{self.caption}": images}, commit=False) diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/population_based_training.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index bd6a491..223d9c6 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -4,6 +4,7 @@ from pathlib import Path import time from typing import Dict, List, Optional, Tuple, Type +from einops import rearrange from loguru import logger import numpy as np import torch @@ -27,12 +28,20 @@ class Trainer: # TODO: proper add teardown? - def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None: + def __init__( + self, + max_epochs: int, + callbacks: List[Type[Callback]], + transformer_model: bool = False, + max_norm: float = 0.0, + ) -> None: """Initialization of the Trainer. Args: max_epochs (int): The maximum number of epochs in the training loop. callbacks (CallbackList): List of callbacks to be called. + transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. + max_norm (float): Max norm for gradient clipping. Defaults to 0.0. """ # Training arguments. @@ -43,6 +52,10 @@ class Trainer: # Flag for setting callbacks. self.callbacks_configured = False + self.transformer_model = transformer_model + + self.max_norm = max_norm + # Model placeholders self.model = None @@ -97,10 +110,15 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Backward pass. # Clear the previous gradients. @@ -110,6 +128,11 @@ class Trainer: # Compute the gradients. loss.backward() + if self.max_norm > 0: + torch.nn.utils.clip_grad_norm_( + self.model.network.parameters(), self.max_norm + ) + # Perform updates using calculated gradients. self.model.optimizer.step() @@ -148,10 +171,15 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Compute metrics. metrics = self.compute_metrics(output, targets, loss, loss_avg) @@ -237,6 +265,8 @@ class Trainer: # Configure callbacks. self._configure_callbacks() + self.callbacks.on_test_begin() + self.model.eval() # Check if SWA network is available. @@ -252,6 +282,8 @@ class Trainer: metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) + self.callbacks.on_test_end() + # Compute mean of all test metrics. metrics_mean = { "test_" + metric: np.mean([x[metric] for x in summary]) -- cgit v1.2.3-70-g09d2