diff options
Diffstat (limited to 'src/training/trainer')
-rw-r--r-- | src/training/trainer/callbacks/base.py | 20 | ||||
-rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 6 | ||||
-rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 5 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 34 | ||||
-rw-r--r-- | src/training/trainer/population_based_training/__init__.py | 1 | ||||
-rw-r--r-- | src/training/trainer/population_based_training/population_based_training.py | 1 | ||||
-rw-r--r-- | src/training/trainer/train.py | 42 |
7 files changed, 89 insertions, 20 deletions
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8c7b085..500b642 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -62,6 +62,14 @@ class Callback: """Called at the end of an epoch.""" pass + def on_test_begin(self) -> None: + """Called at the beginning of test.""" + pass + + def on_test_end(self) -> None: + """Called at the end of test.""" + pass + class CallbackList: """Container for abstracting away callback calls.""" @@ -92,7 +100,7 @@ class CallbackList: def append(self, callback: Type[Callback]) -> None: """Append new callback to callback list.""" - self.callbacks.append(callback) + self._callbacks.append(callback) def on_fit_begin(self) -> None: """Called when fit begins.""" @@ -104,6 +112,16 @@ class CallbackList: for callback in self._callbacks: callback.on_fit_end() + def on_test_begin(self) -> None: + """Called when test begins.""" + for callback in self._callbacks: + callback.on_test_begin() + + def on_test_end(self) -> None: + """Called when test ends.""" + for callback in self._callbacks: + callback.on_test_end() + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" for callback in self._callbacks: diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py index 6fe06d3..a54e0a9 100644 --- a/src/training/trainer/callbacks/checkpoint.py +++ b/src/training/trainer/callbacks/checkpoint.py @@ -21,7 +21,7 @@ class Checkpoint(Callback): def __init__( self, - checkpoint_path: Path, + checkpoint_path: Union[str, Path], monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0, @@ -29,14 +29,14 @@ class Checkpoint(Callback): """Monitors a quantity that will allow us to determine the best model weights. Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. + checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. monitor (str): Name of the quantity to monitor. Defaults to "accuracy". mode (str): Description of parameter `mode`. Defaults to "auto". min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. """ super().__init__() - self.checkpoint_path = checkpoint_path + self.checkpoint_path = Path(checkpoint_path) self.monitor = monitor self.mode = mode self.min_delta = torch.tensor(min_delta) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index 907e292..630c434 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -22,7 +22,10 @@ class LRScheduler(Callback): def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" if self.interval == "epoch": - self.lr_scheduler.step() + if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: + self.lr_scheduler.step(logs["val_loss"]) + else: + self.lr_scheduler.step() def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index d2df4d7..1627f17 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -64,37 +64,55 @@ class WandbImageLogger(Callback): """ super().__init__() + self.caption = None self.example_indices = example_indices + self.test_sample_indices = None self.num_examples = num_examples self.transpose = Transpose() if use_transpose else None def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" self.model = model + self.caption = "Validation Examples" if self.example_indices is None: self.example_indices = np.random.randint( 0, len(self.model.val_dataset), self.num_examples ) - self.val_images = self.model.val_dataset.dataset.data[self.example_indices] - self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] - self.val_targets = self.val_targets.tolist() + self.images = self.model.val_dataset.dataset.data[self.example_indices] + self.targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.targets = self.targets.tolist() + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + self.targets = self.model.test_dataset.targets[self.test_sample_indices] + self.targets = self.targets.tolist() + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) def on_epoch_end(self, epoch: int, logs: Dict) -> None: """Get network predictions on validation images.""" images = [] - for i, image in enumerate(self.val_images): + for i, image in enumerate(self.images): image = self.transpose(image) if self.transpose is not None else image pred, conf = self.model.predict_on_image(image) - if isinstance(self.val_targets[i], list): + if isinstance(self.targets[i], list): ground_truth = "".join( [ self.model.mapper(int(target_index)) - for target_index in self.val_targets[i] + for target_index in self.targets[i] ] ).rstrip("_") else: - ground_truth = self.val_targets[i] + ground_truth = self.model.mapper(int(self.targets[i])) caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) - wandb.log({"examples": images}, commit=False) + wandb.log({f"{self.caption}": images}, commit=False) diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/population_based_training.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index bd6a491..223d9c6 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -4,6 +4,7 @@ from pathlib import Path import time from typing import Dict, List, Optional, Tuple, Type +from einops import rearrange from loguru import logger import numpy as np import torch @@ -27,12 +28,20 @@ class Trainer: # TODO: proper add teardown? - def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None: + def __init__( + self, + max_epochs: int, + callbacks: List[Type[Callback]], + transformer_model: bool = False, + max_norm: float = 0.0, + ) -> None: """Initialization of the Trainer. Args: max_epochs (int): The maximum number of epochs in the training loop. callbacks (CallbackList): List of callbacks to be called. + transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. + max_norm (float): Max norm for gradient clipping. Defaults to 0.0. """ # Training arguments. @@ -43,6 +52,10 @@ class Trainer: # Flag for setting callbacks. self.callbacks_configured = False + self.transformer_model = transformer_model + + self.max_norm = max_norm + # Model placeholders self.model = None @@ -97,10 +110,15 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Backward pass. # Clear the previous gradients. @@ -110,6 +128,11 @@ class Trainer: # Compute the gradients. loss.backward() + if self.max_norm > 0: + torch.nn.utils.clip_grad_norm_( + self.model.network.parameters(), self.max_norm + ) + # Perform updates using calculated gradients. self.model.optimizer.step() @@ -148,10 +171,15 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Compute metrics. metrics = self.compute_metrics(output, targets, loss, loss_avg) @@ -237,6 +265,8 @@ class Trainer: # Configure callbacks. self._configure_callbacks() + self.callbacks.on_test_begin() + self.model.eval() # Check if SWA network is available. @@ -252,6 +282,8 @@ class Trainer: metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) + self.callbacks.on_test_end() + # Compute mean of all test metrics. metrics_mean = { "test_" + metric: np.mean([x[metric] for x in summary]) |