Diffstat (limited to 'src/training')
-rw-r--r--  src/training/experiments/default_config_emnist.yml  1
-rw-r--r--  src/training/experiments/embedding_experiment.yml  64
-rw-r--r--  src/training/experiments/line_ctc_experiment.yml  91
-rw-r--r--  src/training/experiments/sample_experiment.yml  1
-rw-r--r--  src/training/prepare_experiments.py  2
-rw-r--r--  src/training/run_experiment.py  161
-rw-r--r--  src/training/trainer/callbacks/base.py  20
-rw-r--r--  src/training/trainer/callbacks/checkpoint.py  6
-rw-r--r--  src/training/trainer/callbacks/lr_schedulers.py  5
-rw-r--r--  src/training/trainer/callbacks/wandb_callbacks.py  34
-rw-r--r--  src/training/trainer/population_based_training/__init__.py  1
-rw-r--r--  src/training/trainer/population_based_training/population_based_training.py  1
-rw-r--r--  src/training/trainer/train.py  42
13 files changed, 266 insertions, 163 deletions
diff --git a/src/training/experiments/default_config_emnist.yml b/src/training/experiments/default_config_emnist.yml
index 12a0a9d..bf2ed0a 100644
--- a/src/training/experiments/default_config_emnist.yml
+++ b/src/training/experiments/default_config_emnist.yml
@@ -66,4 +66,5 @@ callback_args:
null
verbosity: 1 # 0, 1, 2
resume_experiment: null
+train: true
validation_metric: val_accuracy
diff --git a/src/training/experiments/embedding_experiment.yml b/src/training/experiments/embedding_experiment.yml
new file mode 100644
index 0000000..1e5f941
--- /dev/null
+++ b/src/training/experiments/embedding_experiment.yml
@@ -0,0 +1,64 @@
+experiment_group: Embedding Experiments
+experiments:
+ - train_args:
+ transformer_model: false
+ batch_size: &batch_size 256
+ max_epochs: &max_epochs 32
+ input_shape: [[1, 28, 28]]
+ dataset:
+ type: EmnistDataset
+ args:
+ sample_to_balance: true
+ subsample_fraction: null
+ transform: null
+ target_transform: null
+ seed: 4711
+ train_args:
+ num_workers: 8
+ train_fraction: 0.85
+ batch_size: *batch_size
+ model: CharacterModel
+ metrics: []
+ network:
+ type: DenseNet
+ args:
+ growth_rate: 4
+ block_config: [4, 4]
+ in_channels: 1
+ base_channels: 24
+ num_classes: 128
+ bn_size: 4
+ dropout_rate: 0.1
+ classifier: true
+ activation: elu
+ criterion:
+ type: EmbeddingLoss
+ args:
+ margin: 0.2
+ type_of_triplets: semihard
+ optimizer:
+ type: AdamW
+ args:
+ lr: 1.e-02
+ betas: [0.9, 0.999]
+ eps: 1.e-08
+ weight_decay: 5.e-4
+ amsgrad: false
+ lr_scheduler:
+ type: CosineAnnealingLR
+ args:
+ T_max: *max_epochs
+ callbacks: [Checkpoint, ProgressBar, WandbCallback]
+ callback_args:
+ Checkpoint:
+ monitor: val_loss
+ mode: min
+ ProgressBar:
+ epochs: *max_epochs
+ WandbCallback:
+ log_batch_frequency: 10
+ verbosity: 1 # 0, 1, 2
+ resume_experiment: null
+ train: true
+ test: true
+ test_metric: mean_average_precision_at_r
diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml
deleted file mode 100644
index 432d1cc..0000000
--- a/src/training/experiments/line_ctc_experiment.yml
+++ /dev/null
@@ -1,91 +0,0 @@
-experiment_group: Lines Experiments
-experiments:
- - train_args:
- batch_size: 42
- max_epochs: &max_epochs 32
- dataset:
- type: IamLinesDataset
- args:
- subsample_fraction: null
- transform: null
- target_transform: null
- train_args:
- num_workers: 8
- train_fraction: 0.85
- model: LineCTCModel
- metrics: [cer, wer]
- network:
- type: LineRecurrentNetwork
- args:
- backbone: ResidualNetwork
- backbone_args:
- in_channels: 1
- num_classes: 64 # Embedding
- depths: [2,2]
- block_sizes: [32,64]
- activation: selu
- stn: false
- # encoder: ResidualNetwork
- # encoder_args:
- # pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0917_203601/model/best.pt
- # freeze: false
- flatten: false
- input_size: 64
- hidden_size: 64
- bidirectional: true
- num_layers: 2
- num_classes: 80
- patch_size: [28, 18]
- stride: [1, 4]
- criterion:
- type: CTCLoss
- args:
- blank: 79
- optimizer:
- type: AdamW
- args:
- lr: 1.e-02
- betas: [0.9, 0.999]
- eps: 1.e-08
- weight_decay: 5.e-4
- amsgrad: false
- lr_scheduler:
- type: OneCycleLR
- args:
- max_lr: 1.e-02
- epochs: *max_epochs
- anneal_strategy: cos
- pct_start: 0.475
- cycle_momentum: true
- base_momentum: 0.85
- max_momentum: 0.9
- div_factor: 10
- final_div_factor: 10000
- interval: step
- # lr_scheduler:
- # type: CosineAnnealingLR
- # args:
- # T_max: *max_epochs
- swa_args:
- start: 24
- lr: 5.e-2
- callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger] # EarlyStopping]
- callback_args:
- Checkpoint:
- monitor: val_loss
- mode: min
- ProgressBar:
- epochs: *max_epochs
- # EarlyStopping:
- # monitor: val_loss
- # min_delta: 0.0
- # patience: 10
- # mode: min
- WandbCallback:
- log_batch_frequency: 10
- WandbImageLogger:
- num_examples: 6
- verbosity: 1 # 0, 1, 2
- resume_experiment: null
- test: true
- test_metric: test_cer
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 8664a15..a073a87 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -95,5 +95,6 @@ experiments:
use_transpose: true
verbosity: 0 # 0, 1, 2
resume_experiment: null
+ train: true
test: true
test_metric: test_accuracy
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
index e00540c..6e20bcd 100644
--- a/src/training/prepare_experiments.py
+++ b/src/training/prepare_experiments.py
@@ -1,9 +1,7 @@
"""Run a experiment from a config file."""
import json
-from subprocess import run
import click
-from loguru import logger
import yaml
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index a347d9f..0510d5c 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,12 +6,15 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, List, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type
+import warnings
+import adabelief_pytorch
import click
from loguru import logger
import numpy as np
import torch
+from torchsummary import summary
from tqdm import tqdm
from training.gpu_manager import GPUManager
from training.trainer.callbacks import Callback, CallbackList
@@ -21,26 +24,23 @@ import yaml
from text_recognizer.models import Model
-from text_recognizer.networks import losses
-
+from text_recognizer.networks import loss as custom_loss_module
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-CUSTOM_LOSSES = ["EmbeddingLoss"]
DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}
-def get_level(experiment_config: Dict) -> int:
+def _get_level(verbose: int) -> int:
"""Sets the logger level."""
- if experiment_config["verbosity"] == 0:
- return 40
- elif experiment_config["verbosity"] == 1:
- return 20
- else:
- return 10
+ levels = {0: 40, 1: 20, 2: 10}
+ verbose = verbose if verbose <= 2 else 2
+ return levels[verbose]
-def create_experiment_dir(experiment_config: Dict) -> Path:
+def _create_experiment_dir(
+ experiment_config: Dict, checkpoint: Optional[str] = None
+) -> Path:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment_dir = EXPERIMENTS_DIRNAME / (
@@ -48,19 +48,21 @@ def create_experiment_dir(experiment_config: Dict) -> Path:
+ f"{experiment_config['dataset']['type']}_"
+ f"{experiment_config['network']['type']}"
)
- if experiment_config["resume_experiment"] is None:
+
+ if checkpoint is None:
experiment = datetime.now().strftime("%m%d_%H%M%S")
logger.debug(f"Creating a new experiment called {experiment}")
else:
available_experiments = glob(str(experiment_dir) + "/*")
available_experiments.sort()
- if experiment_config["resume_experiment"] == "last":
+ if checkpoint == "last":
experiment = available_experiments[-1]
logger.debug(f"Resuming the latest experiment {experiment}")
else:
- experiment = experiment_config["resume_experiment"]
+ experiment = checkpoint
if not str(experiment_dir / experiment) in available_experiments:
raise FileNotFoundError("Experiment does not exist.")
+ logger.debug(f"Resuming the from experiment {checkpoint}")
experiment_dir = experiment_dir / experiment
@@ -71,14 +73,10 @@ def create_experiment_dir(experiment_config: Dict) -> Path:
return experiment_dir, log_dir, model_dir
-def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
+def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
- # Import the data loader arguments.
- train_args = experiment_config.get("train_args", {})
-
# Load the dataset module.
dataset_args = experiment_config.get("dataset", {})
- dataset_args["train_args"]["batch_size"] = train_args["batch_size"]
datasets_module = importlib.import_module("text_recognizer.datasets")
dataset_ = getattr(datasets_module, dataset_args["type"])
@@ -102,21 +100,24 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
network_args = experiment_config["network"].get("args", {})
# Criterion
- if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
- criterion_ = getattr(losses, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {})
+ if experiment_config["criterion"]["type"] in custom_loss_module.__all__:
+ criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"])
else:
criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {})
+ criterion_args = experiment_config["criterion"].get("args", {}) or {}
# Optimizers
- optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+ if experiment_config["optimizer"]["type"] == "AdaBelief":
+ warnings.filterwarnings("ignore", category=UserWarning)
+ optimizer_ = getattr(adabelief_pytorch, experiment_config["optimizer"]["type"])
+ else:
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
optimizer_args = experiment_config["optimizer"].get("args", {})
# Learning rate scheduler
lr_scheduler_ = None
lr_scheduler_args = None
- if experiment_config["lr_scheduler"] is not None:
+ if "lr_scheduler" in experiment_config:
lr_scheduler_ = getattr(
torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
)
@@ -146,10 +147,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
return model_class_, model_args
-def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList:
+def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList:
"""Configure a callback list for trainer."""
if "Checkpoint" in experiment_config["callback_args"]:
- experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir
+ experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str(
+ model_dir
+ )
# Initializes callbacks.
callback_modules = importlib.import_module("training.trainer.callbacks")
@@ -161,13 +164,13 @@ def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackLis
return callbacks
-def configure_logger(experiment_config: Dict, log_dir: Path) -> None:
+def _configure_logger(log_dir: Path, verbose: int = 0) -> None:
"""Configure the loguru logger for output to terminal and disk."""
# Have to remove default logger to get tqdm to work properly.
logger.remove()
# Fetch verbosity level.
- level = get_level(experiment_config)
+ level = _get_level(verbose)
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
logger.add(
@@ -176,20 +179,29 @@ def configure_logger(experiment_config: Dict, log_dir: Path) -> None:
)
-def save_config(experiment_dir: Path, experiment_config: Dict) -> None:
+def _save_config(experiment_dir: Path, experiment_config: Dict) -> None:
"""Copy config to experiment directory."""
config_path = experiment_dir / "config.yml"
with open(str(config_path), "w") as f:
yaml.dump(experiment_config, f)
-def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None:
+def _load_from_checkpoint(
+ model: Type[Model], model_dir: Path, pretrained_weights: str = None,
+) -> None:
"""If checkpoint exists, load model weights and optimizers from checkpoint."""
# Get checkpoint path.
- checkpoint_path = model_dir / "last.pt"
+ if pretrained_weights is not None:
+ logger.info(f"Loading weights from {pretrained_weights}.")
+ checkpoint_path = (
+ EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt"
+ )
+ else:
+ logger.info(f"Loading weights from {model_dir}.")
+ checkpoint_path = model_dir / "last.pt"
if checkpoint_path.exists():
- logger.info("Loading and resuming training from last checkpoint.")
- model.load_checkpoint(checkpoint_path)
+ logger.info("Loading and resuming training from checkpoint.")
+ model.load_from_checkpoint(checkpoint_path)
def evaluate_embedding(model: Type[Model]) -> Dict:
@@ -217,38 +229,50 @@ def evaluate_embedding(model: Type[Model]) -> Dict:
def run_experiment(
- experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
+ experiment_config: Dict,
+ save_weights: bool,
+ device: str,
+ use_wandb: bool,
+ train: bool,
+ test: bool,
+ verbose: int = 0,
+ checkpoint: Optional[str] = None,
+ pretrained_weights: Optional[str] = None,
) -> None:
"""Runs an experiment."""
logger.info(f"Experiment config: {json.dumps(experiment_config)}")
# Create new experiment.
- experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)
+ experiment_dir, log_dir, model_dir = _create_experiment_dir(
+ experiment_config, checkpoint
+ )
# Make sure the log/model directory exists.
log_dir.mkdir(parents=True, exist_ok=True)
model_dir.mkdir(parents=True, exist_ok=True)
# Load the modules and model arguments.
- model_class_, model_args = load_modules_and_arguments(experiment_config)
+ model_class_, model_args = _load_modules_and_arguments(experiment_config)
# Initializes the model with experiment config.
model = model_class_(**model_args, device=device)
- callbacks = configure_callbacks(experiment_config, model_dir)
+ callbacks = _configure_callbacks(experiment_config, model_dir)
# Setup logger.
- configure_logger(experiment_config, log_dir)
+ _configure_logger(log_dir, verbose)
# Load from checkpoint if resuming an experiment.
- if experiment_config["resume_experiment"] is not None:
- load_from_checkpoint(model, log_dir, model_dir)
+ resume = False
+ if checkpoint is not None or pretrained_weights is not None:
+ resume = True
+ _load_from_checkpoint(model, model_dir, pretrained_weights)
logger.info(f"The class mapping is {model.mapping}")
# Initializes Weights & Biases
if use_wandb:
- wandb.init(project="text-recognizer", config=experiment_config)
+ wandb.init(project="text-recognizer", config=experiment_config, resume=resume)
# Lets W&B save the model and track the gradients and optional parameters.
wandb.watch(model.network)
@@ -265,23 +289,30 @@ def run_experiment(
experiment_config["device"] = device
# Save the config used in the experiment folder.
- save_config(experiment_dir, experiment_config)
+ _save_config(experiment_dir, experiment_config)
+
+ # Prints a summary of the network in terminal.
+ model.summary(experiment_config["train_args"]["input_shape"])
# Load trainer.
trainer = Trainer(
- max_epochs=experiment_config["train_args"]["max_epochs"], callbacks=callbacks,
+ max_epochs=experiment_config["train_args"]["max_epochs"],
+ callbacks=callbacks,
+ transformer_model=experiment_config["train_args"]["transformer_model"],
+ max_norm=experiment_config["train_args"]["max_norm"],
)
# Train the model.
- trainer.fit(model)
+ if train:
+ trainer.fit(model)
# Run inference over test set.
- if experiment_config["test"]:
+ if test:
logger.info("Loading checkpoint with the best weights.")
model.load_from_checkpoint(model_dir / "best.pt")
logger.info("Running inference on test set.")
- if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
+ if experiment_config["criterion"]["type"] == "EmbeddingLoss":
logger.info("Evaluating embedding.")
score = evaluate_embedding(model)
else:
@@ -313,7 +344,26 @@ def run_experiment(
@click.option(
"--nowandb", is_flag=False, help="If true, do not use wandb for this run."
)
-def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
+@click.option("--test", is_flag=True, help="If true, test the model.")
+@click.option("-v", "--verbose", count=True)
+@click.option("--checkpoint", type=str, help="Path to the experiment.")
+@click.option(
+ "--pretrained_weights", type=str, help="Path to pretrained model weights."
+)
+@click.option(
+ "--notrain", is_flag=False, is_eager=True, help="Do not train the model.",
+)
+def run_cli(
+ experiment_config: str,
+ gpu: int,
+ save: bool,
+ nowandb: bool,
+ notrain: bool,
+ test: bool,
+ verbose: int,
+ checkpoint: Optional[str] = None,
+ pretrained_weights: Optional[str] = None,
+) -> None:
"""Run experiment."""
if gpu < 0:
gpu_manager = GPUManager(True)
@@ -322,7 +372,18 @@ def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None
experiment_config = json.loads(experiment_config)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
- run_experiment(experiment_config, save, device, use_wandb=not nowandb)
+
+ run_experiment(
+ experiment_config,
+ save,
+ device,
+ use_wandb=not nowandb,
+ train=not notrain,
+ test=test,
+ verbose=verbose,
+ checkpoint=checkpoint,
+ pretrained_weights=pretrained_weights,
+ )
if __name__ == "__main__":
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index 8c7b085..500b642 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -62,6 +62,14 @@ class Callback:
"""Called at the end of an epoch."""
pass
+ def on_test_begin(self) -> None:
+ """Called at the beginning of test."""
+ pass
+
+ def on_test_end(self) -> None:
+ """Called at the end of test."""
+ pass
+
class CallbackList:
"""Container for abstracting away callback calls."""
@@ -92,7 +100,7 @@ class CallbackList:
def append(self, callback: Type[Callback]) -> None:
"""Append new callback to callback list."""
- self.callbacks.append(callback)
+ self._callbacks.append(callback)
def on_fit_begin(self) -> None:
"""Called when fit begins."""
@@ -104,6 +112,16 @@ class CallbackList:
for callback in self._callbacks:
callback.on_fit_end()
+ def on_test_begin(self) -> None:
+ """Called when test begins."""
+ for callback in self._callbacks:
+ callback.on_test_begin()
+
+ def on_test_end(self) -> None:
+ """Called when test ends."""
+ for callback in self._callbacks:
+ callback.on_test_end()
+
def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
for callback in self._callbacks:
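
A short sketch of a callback using the new test hooks; the class and its timing logic below are hypothetical, only the hook names and the Callback import come from the diff.

    # Hypothetical callback exercising the new on_test_begin/on_test_end hooks.
    import time
    from training.trainer.callbacks import Callback

    class TestTimer(Callback):
        """Logs the wall-clock duration of the test pass."""

        def on_test_begin(self) -> None:
            self._start = time.time()

        def on_test_end(self) -> None:
            print(f"Test pass took {time.time() - self._start:.1f}s")
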
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py
index 6fe06d3..a54e0a9 100644
--- a/src/training/trainer/callbacks/checkpoint.py
+++ b/src/training/trainer/callbacks/checkpoint.py
@@ -21,7 +21,7 @@ class Checkpoint(Callback):
def __init__(
self,
- checkpoint_path: Path,
+ checkpoint_path: Union[str, Path],
monitor: str = "accuracy",
mode: str = "auto",
min_delta: float = 0.0,
@@ -29,14 +29,14 @@ class Checkpoint(Callback):
"""Monitors a quantity that will allow us to determine the best model weights.
Args:
- checkpoint_path (Path): Path to the experiment with the checkpoint.
+ checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
mode (str): One of "min", "max", or "auto"; how the monitored quantity is compared. Defaults to "auto".
min_delta (float): Minimum change in the monitored quantity that counts as an improvement. Defaults to 0.0.
"""
super().__init__()
- self.checkpoint_path = checkpoint_path
+ self.checkpoint_path = Path(checkpoint_path)
self.monitor = monitor
self.mode = mode
self.min_delta = torch.tensor(min_delta)
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index 907e292..630c434 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -22,7 +22,10 @@ class LRScheduler(Callback):
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
if self.interval == "epoch":
- self.lr_scheduler.step()
+ if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
+ self.lr_scheduler.step(logs["val_loss"])
+ else:
+ self.lr_scheduler.step()
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index d2df4d7..1627f17 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -64,37 +64,55 @@ class WandbImageLogger(Callback):
"""
super().__init__()
+ self.caption = None
self.example_indices = example_indices
+ self.test_sample_indices = None
self.num_examples = num_examples
self.transpose = Transpose() if use_transpose else None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
self.model = model
+ self.caption = "Validation Examples"
if self.example_indices is None:
self.example_indices = np.random.randint(
0, len(self.model.val_dataset), self.num_examples
)
- self.val_images = self.model.val_dataset.dataset.data[self.example_indices]
- self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices]
- self.val_targets = self.val_targets.tolist()
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+ self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
"""Get network predictions on validation images."""
images = []
- for i, image in enumerate(self.val_images):
+ for i, image in enumerate(self.images):
image = self.transpose(image) if self.transpose is not None else image
pred, conf = self.model.predict_on_image(image)
- if isinstance(self.val_targets[i], list):
+ if isinstance(self.targets[i], list):
ground_truth = "".join(
[
self.model.mapper(int(target_index))
- for target_index in self.val_targets[i]
+ for target_index in self.targets[i]
]
).rstrip("_")
else:
- ground_truth = self.val_targets[i]
+ ground_truth = self.model.mapper(int(self.targets[i]))
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))
- wandb.log({"examples": images}, commit=False)
+ wandb.log({f"{self.caption}": images}, commit=False)
diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py
deleted file mode 100644
index 868d739..0000000
--- a/src/training/trainer/population_based_training/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""TBC."""
diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py
deleted file mode 100644
index 868d739..0000000
--- a/src/training/trainer/population_based_training/population_based_training.py
+++ /dev/null
@@ -1 +0,0 @@
-"""TBC."""
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index bd6a491..223d9c6 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -4,6 +4,7 @@ from pathlib import Path
import time
from typing import Dict, List, Optional, Tuple, Type
+from einops import rearrange
from loguru import logger
import numpy as np
import torch
@@ -27,12 +28,20 @@ class Trainer:
# TODO: properly add teardown?
- def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None:
+ def __init__(
+ self,
+ max_epochs: int,
+ callbacks: List[Type[Callback]],
+ transformer_model: bool = False,
+ max_norm: float = 0.0,
+ ) -> None:
"""Initialization of the Trainer.
Args:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
+ transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+ max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
"""
# Training arguments.
@@ -43,6 +52,10 @@ class Trainer:
# Flag for setting callbacks.
self.callbacks_configured = False
+ self.transformer_model = transformer_model
+
+ self.max_norm = max_norm
+
# Model placeholders
self.model = None
@@ -97,10 +110,15 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.forward(data)
+ if self.transformer_model:
+ output = self.model.network.forward(data, targets[:, :-1])
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
# Compute the loss.
- loss = self.model.loss_fn(output, targets)
+ loss = self.model.criterion(output, targets)
# Backward pass.
# Clear the previous gradients.
@@ -110,6 +128,11 @@ class Trainer:
# Compute the gradients.
loss.backward()
+ if self.max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.network.parameters(), self.max_norm
+ )
+
# Perform updates using calculated gradients.
self.model.optimizer.step()
@@ -148,10 +171,15 @@ class Trainer:
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
- output = self.model.forward(data)
+ if self.transformer_model:
+ output = self.model.network.forward(data, targets[:, :-1])
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
# Compute the loss.
- loss = self.model.loss_fn(output, targets)
+ loss = self.model.criterion(output, targets)
# Compute metrics.
metrics = self.compute_metrics(output, targets, loss, loss_avg)
@@ -237,6 +265,8 @@ class Trainer:
# Configure callbacks.
self._configure_callbacks()
+ self.callbacks.on_test_begin()
+
self.model.eval()
# Check if SWA network is available.
@@ -252,6 +282,8 @@ class Trainer:
metrics = self.validation_step(batch, samples, loss_avg)
summary.append(metrics)
+ self.callbacks.on_test_end()
+
# Compute mean of all test metrics.
metrics_mean = {
"test_" + metric: np.mean([x[metric] for x in summary])