summaryrefslogtreecommitdiff
path: root/src/training/trainer
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
commitdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree1b5fc0d06952e13727e85c4f973a26d277068453 /src/training/trainer
parente181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/training/trainer')
-rw-r--r--src/training/trainer/callbacks/base.py20
-rw-r--r--src/training/trainer/callbacks/checkpoint.py6
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py5
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py34
-rw-r--r--src/training/trainer/population_based_training/__init__.py1
-rw-r--r--src/training/trainer/population_based_training/population_based_training.py1
-rw-r--r--src/training/trainer/train.py42
7 files changed, 89 insertions, 20 deletions
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index 8c7b085..500b642 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -62,6 +62,14 @@ class Callback:
"""Called at the end of an epoch."""
pass
+ def on_test_begin(self) -> None:
+ """Called at the beginning of test."""
+ pass
+
+ def on_test_end(self) -> None:
+ """Called at the end of test."""
+ pass
+
class CallbackList:
"""Container for abstracting away callback calls."""
@@ -92,7 +100,7 @@ class CallbackList:
def append(self, callback: Type[Callback]) -> None:
"""Append new callback to callback list."""
- self.callbacks.append(callback)
+ self._callbacks.append(callback)
def on_fit_begin(self) -> None:
"""Called when fit begins."""
@@ -104,6 +112,16 @@ class CallbackList:
for callback in self._callbacks:
callback.on_fit_end()
+ def on_test_begin(self) -> None:
+ """Called when test begins."""
+ for callback in self._callbacks:
+ callback.on_test_begin()
+
+ def on_test_end(self) -> None:
+ """Called when test ends."""
+ for callback in self._callbacks:
+ callback.on_test_end()
+
def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
for callback in self._callbacks:
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py
index 6fe06d3..a54e0a9 100644
--- a/src/training/trainer/callbacks/checkpoint.py
+++ b/src/training/trainer/callbacks/checkpoint.py
@@ -21,7 +21,7 @@ class Checkpoint(Callback):
def __init__(
self,
- checkpoint_path: Path,
+ checkpoint_path: Union[str, Path],
monitor: str = "accuracy",
mode: str = "auto",
min_delta: float = 0.0,
@@ -29,14 +29,14 @@ class Checkpoint(Callback):
"""Monitors a quantity that will allow us to determine the best model weights.
Args:
- checkpoint_path (Path): Path to the experiment with the checkpoint.
+ checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
mode (str): Description of parameter `mode`. Defaults to "auto".
min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
"""
super().__init__()
- self.checkpoint_path = checkpoint_path
+ self.checkpoint_path = Path(checkpoint_path)
self.monitor = monitor
self.mode = mode
self.min_delta = torch.tensor(min_delta)
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index 907e292..630c434 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -22,7 +22,10 @@ class LRScheduler(Callback):
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
if self.interval == "epoch":
- self.lr_scheduler.step()
+ if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
+ self.lr_scheduler.step(logs["val_loss"])
+ else:
+ self.lr_scheduler.step()
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index d2df4d7..1627f17 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -64,37 +64,55 @@ class WandbImageLogger(Callback):
"""
super().__init__()
+ self.caption = None
self.example_indices = example_indices
+ self.test_sample_indices = None
self.num_examples = num_examples
self.transpose = Transpose() if use_transpose else None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
self.model = model
+ self.caption = "Validation Examples"
if self.example_indices is None:
self.example_indices = np.random.randint(
0, len(self.model.val_dataset), self.num_examples
)
- self.val_images = self.model.val_dataset.dataset.data[self.example_indices]
- self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices]
- self.val_targets = self.val_targets.tolist()
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+ self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
"""Get network predictions on validation images."""
images = []
- for i, image in enumerate(self.val_images):
+ for i, image in enumerate(self.images):
image = self.transpose(image) if self.transpose is not None else image
pred, conf = self.model.predict_on_image(image)
- if isinstance(self.val_targets[i], list):
+ if isinstance(self.targets[i], list):
ground_truth = "".join(
[
self.model.mapper(int(target_index))
- for target_index in self.val_targets[i]
+ for target_index in self.targets[i]
]
).rstrip("_")
else:
- ground_truth = self.val_targets[i]
+ ground_truth = self.model.mapper(int(self.targets[i]))
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))
- wandb.log({"examples": images}, commit=False)
+ wandb.log({f"{self.caption}": images}, commit=False)
diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py
deleted file mode 100644
index 868d739..0000000
--- a/src/training/trainer/population_based_training/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""TBC."""
diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py
deleted file mode 100644
index 868d739..0000000
--- a/src/training/trainer/population_based_training/population_based_training.py
+++ /dev/null
@@ -1 +0,0 @@
-"""TBC."""
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index bd6a491..223d9c6 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -4,6 +4,7 @@ from pathlib import Path
import time
from typing import Dict, List, Optional, Tuple, Type
+from einops import rearrange
from loguru import logger
import numpy as np
import torch
@@ -27,12 +28,20 @@ class Trainer:
# TODO: proper add teardown?
- def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None:
+ def __init__(
+ self,
+ max_epochs: int,
+ callbacks: List[Type[Callback]],
+ transformer_model: bool = False,
+ max_norm: float = 0.0,
+ ) -> None:
"""Initialization of the Trainer.
Args:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
+ transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+ max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
"""
# Training arguments.
@@ -43,6 +52,10 @@ class Trainer:
# Flag for setting callbacks.
self.callbacks_configured = False
+ self.transformer_model = transformer_model
+
+ self.max_norm = max_norm
+
# Model placeholders
self.model = None
@@ -97,10 +110,15 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.forward(data)
+ if self.transformer_model:
+ output = self.model.network.forward(data, targets[:, :-1])
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
# Compute the loss.
- loss = self.model.loss_fn(output, targets)
+ loss = self.model.criterion(output, targets)
# Backward pass.
# Clear the previous gradients.
@@ -110,6 +128,11 @@ class Trainer:
# Compute the gradients.
loss.backward()
+ if self.max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.network.parameters(), self.max_norm
+ )
+
# Perform updates using calculated gradients.
self.model.optimizer.step()
@@ -148,10 +171,15 @@ class Trainer:
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
- output = self.model.forward(data)
+ if self.transformer_model:
+ output = self.model.network.forward(data, targets[:, :-1])
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
# Compute the loss.
- loss = self.model.loss_fn(output, targets)
+ loss = self.model.criterion(output, targets)
# Compute metrics.
metrics = self.compute_metrics(output, targets, loss, loss_avg)
@@ -237,6 +265,8 @@ class Trainer:
# Configure callbacks.
self._configure_callbacks()
+ self.callbacks.on_test_begin()
+
self.model.eval()
# Check if SWA network is available.
@@ -252,6 +282,8 @@ class Trainer:
metrics = self.validation_step(batch, samples, loss_avg)
summary.append(metrics)
+ self.callbacks.on_test_end()
+
# Compute mean of all test metrics.
metrics_mean = {
"test_" + metric: np.mean([x[metric] for x in summary])