diff options
Diffstat (limited to 'src/training/trainer')
| -rw-r--r-- | src/training/trainer/callbacks/base.py | 20 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 6 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 5 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 34 | ||||
| -rw-r--r-- | src/training/trainer/population_based_training/__init__.py | 1 | ||||
| -rw-r--r-- | src/training/trainer/population_based_training/population_based_training.py | 1 | ||||
| -rw-r--r-- | src/training/trainer/train.py | 42 | 
7 files changed, 89 insertions, 20 deletions
| diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8c7b085..500b642 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -62,6 +62,14 @@ class Callback:          """Called at the end of an epoch."""          pass +    def on_test_begin(self) -> None: +        """Called at the beginning of test.""" +        pass + +    def on_test_end(self) -> None: +        """Called at the end of test.""" +        pass +  class CallbackList:      """Container for abstracting away callback calls.""" @@ -92,7 +100,7 @@ class CallbackList:      def append(self, callback: Type[Callback]) -> None:          """Append new callback to callback list.""" -        self.callbacks.append(callback) +        self._callbacks.append(callback)      def on_fit_begin(self) -> None:          """Called when fit begins.""" @@ -104,6 +112,16 @@ class CallbackList:          for callback in self._callbacks:              callback.on_fit_end() +    def on_test_begin(self) -> None: +        """Called when test begins.""" +        for callback in self._callbacks: +            callback.on_test_begin() + +    def on_test_end(self) -> None: +        """Called when test ends.""" +        for callback in self._callbacks: +            callback.on_test_end() +      def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:          """Called at the beginning of an epoch."""          for callback in self._callbacks: diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py index 6fe06d3..a54e0a9 100644 --- a/src/training/trainer/callbacks/checkpoint.py +++ b/src/training/trainer/callbacks/checkpoint.py @@ -21,7 +21,7 @@ class Checkpoint(Callback):      def __init__(          self, -        checkpoint_path: Path, +        checkpoint_path: Union[str, Path],          monitor: str = "accuracy",          mode: str = "auto",          min_delta: float = 0.0, @@ -29,14 +29,14 @@ class Checkpoint(Callback):          """Monitors a quantity that will allow us to determine the best model weights.          Args: -            checkpoint_path (Path): Path to the experiment with the checkpoint. +            checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.              monitor (str): Name of the quantity to monitor. Defaults to "accuracy".              mode (str): Description of parameter `mode`. Defaults to "auto".              min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.          """          super().__init__() -        self.checkpoint_path = checkpoint_path +        self.checkpoint_path = Path(checkpoint_path)          self.monitor = monitor          self.mode = mode          self.min_delta = torch.tensor(min_delta) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index 907e292..630c434 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -22,7 +22,10 @@ class LRScheduler(Callback):      def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:          """Takes a step at the end of every epoch."""          if self.interval == "epoch": -            self.lr_scheduler.step() +            if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: +                self.lr_scheduler.step(logs["val_loss"]) +            else: +                self.lr_scheduler.step()      def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:          """Takes a step at the end of every training batch.""" diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index d2df4d7..1627f17 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -64,37 +64,55 @@ class WandbImageLogger(Callback):          """          super().__init__() +        self.caption = None          self.example_indices = example_indices +        self.test_sample_indices = None          self.num_examples = num_examples          self.transpose = Transpose() if use_transpose else None      def set_model(self, model: Type[Model]) -> None:          """Sets the model and extracts validation images from the dataset."""          self.model = model +        self.caption = "Validation Examples"          if self.example_indices is None:              self.example_indices = np.random.randint(                  0, len(self.model.val_dataset), self.num_examples              ) -        self.val_images = self.model.val_dataset.dataset.data[self.example_indices] -        self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] -        self.val_targets = self.val_targets.tolist() +        self.images = self.model.val_dataset.dataset.data[self.example_indices] +        self.targets = self.model.val_dataset.dataset.targets[self.example_indices] +        self.targets = self.targets.tolist() + +    def on_test_begin(self) -> None: +        """Get samples from test dataset.""" +        self.caption = "Test Examples" +        if self.test_sample_indices is None: +            self.test_sample_indices = np.random.randint( +                0, len(self.model.test_dataset), self.num_examples +            ) +        self.images = self.model.test_dataset.data[self.test_sample_indices] +        self.targets = self.model.test_dataset.targets[self.test_sample_indices] +        self.targets = self.targets.tolist() + +    def on_test_end(self) -> None: +        """Log test images.""" +        self.on_epoch_end(0, {})      def on_epoch_end(self, epoch: int, logs: Dict) -> None:          """Get network predictions on validation images."""          images = [] -        for i, image in enumerate(self.val_images): +        for i, image in enumerate(self.images):              image = self.transpose(image) if self.transpose is not None else image              pred, conf = self.model.predict_on_image(image) -            if isinstance(self.val_targets[i], list): +            if isinstance(self.targets[i], list):                  ground_truth = "".join(                      [                          self.model.mapper(int(target_index)) -                        for target_index in self.val_targets[i] +                        for target_index in self.targets[i]                      ]                  ).rstrip("_")              else: -                ground_truth = self.val_targets[i] +                ground_truth = self.model.mapper(int(self.targets[i]))              caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"              images.append(wandb.Image(image, caption=caption)) -        wandb.log({"examples": images}, commit=False) +        wandb.log({f"{self.caption}": images}, commit=False) diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/population_based_training.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index bd6a491..223d9c6 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -4,6 +4,7 @@ from pathlib import Path  import time  from typing import Dict, List, Optional, Tuple, Type +from einops import rearrange  from loguru import logger  import numpy as np  import torch @@ -27,12 +28,20 @@ class Trainer:      # TODO: proper add teardown? -    def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None: +    def __init__( +        self, +        max_epochs: int, +        callbacks: List[Type[Callback]], +        transformer_model: bool = False, +        max_norm: float = 0.0, +    ) -> None:          """Initialization of the Trainer.          Args:              max_epochs (int): The maximum number of epochs in the training loop.              callbacks (CallbackList): List of callbacks to be called. +            transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. +            max_norm (float): Max norm for gradient clipping. Defaults to 0.0.          """          # Training arguments. @@ -43,6 +52,10 @@ class Trainer:          # Flag for setting callbacks.          self.callbacks_configured = False +        self.transformer_model = transformer_model + +        self.max_norm = max_norm +          # Model placeholders          self.model = None @@ -97,10 +110,15 @@ class Trainer:          # Forward pass.          # Get the network prediction. -        output = self.model.forward(data) +        if self.transformer_model: +            output = self.model.network.forward(data, targets[:, :-1]) +            output = rearrange(output, "b t v -> (b t) v") +            targets = rearrange(targets[:, 1:], "b t -> (b t)").long() +        else: +            output = self.model.forward(data)          # Compute the loss. -        loss = self.model.loss_fn(output, targets) +        loss = self.model.criterion(output, targets)          # Backward pass.          # Clear the previous gradients. @@ -110,6 +128,11 @@ class Trainer:          # Compute the gradients.          loss.backward() +        if self.max_norm > 0: +            torch.nn.utils.clip_grad_norm_( +                self.model.network.parameters(), self.max_norm +            ) +          # Perform updates using calculated gradients.          self.model.optimizer.step() @@ -148,10 +171,15 @@ class Trainer:          # Forward pass.          # Get the network prediction.          # Use SWA if available and using test dataset. -        output = self.model.forward(data) +        if self.transformer_model: +            output = self.model.network.forward(data, targets[:, :-1]) +            output = rearrange(output, "b t v -> (b t) v") +            targets = rearrange(targets[:, 1:], "b t -> (b t)").long() +        else: +            output = self.model.forward(data)          # Compute the loss. -        loss = self.model.loss_fn(output, targets) +        loss = self.model.criterion(output, targets)          # Compute metrics.          metrics = self.compute_metrics(output, targets, loss, loss_avg) @@ -237,6 +265,8 @@ class Trainer:          # Configure callbacks.          self._configure_callbacks() +        self.callbacks.on_test_begin() +          self.model.eval()          # Check if SWA network is available. @@ -252,6 +282,8 @@ class Trainer:              metrics = self.validation_step(batch, samples, loss_avg)              summary.append(metrics) +        self.callbacks.on_test_end() +          # Compute mean of all test metrics.          metrics_mean = {              "test_" + metric: np.mean([x[metric] for x in summary]) |