summaryrefslogtreecommitdiff
path: root/src/training/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/train.py')
-rw-r--r--src/training/train.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/src/training/train.py b/src/training/train.py
index 783de02..4a452b6 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -9,6 +9,8 @@ import numpy as np
import torch
from tqdm import tqdm, trange
from training.util import RunningAverage
+import wandb
+
torch.backends.cudnn.benchmark = True
np.random.seed(4711)
@@ -30,6 +32,7 @@ class Trainer:
epochs: int,
val_metric: str = "accuracy",
checkpoint_path: Optional[Path] = None,
+ use_wandb: Optional[bool] = False,
) -> None:
"""Initialization of the Trainer.
@@ -38,6 +41,7 @@ class Trainer:
epochs (int): Number of epochs to train.
val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy".
checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
+ use_wandb (Optional[bool]): Sync training to wandb.
"""
self.model = model
@@ -48,13 +52,16 @@ class Trainer:
if self.checkpoint_path is not None:
self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
+ if use_wandb:
+ # TODO implement wandb logging.
+ pass
+
self.val_metric = val_metric
self.best_val_metric = 0.0
logger.add(self.model.name + "_{time}.log")
def train(self) -> None:
"""Training loop."""
- # TODO add summary
# Set model to traning mode.
self.model.train()
@@ -93,10 +100,6 @@ class Trainer:
# Perform updates using calculated gradients.
self.model.optimizer.step()
- # Update the learning rate scheduler.
- if self.model.lr_scheduler is not None:
- self.model.lr_scheduler.step()
-
# Compute metrics.
loss_avg.update(loss.item())
output = output.data.cpu()
@@ -174,8 +177,8 @@ class Trainer:
return metrics_mean
- def run(self) -> None:
- """Training and evaluation loop."""
+ def fit(self) -> None:
+ """Runs the training and evaluation loop."""
# Create new experiment.
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment = datetime.now().strftime("%m%d_%H%M%S")