diff options
Diffstat (limited to 'src/training/train.py')
-rw-r--r-- | src/training/train.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/src/training/train.py b/src/training/train.py index 783de02..4a452b6 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -9,6 +9,8 @@ import numpy as np import torch from tqdm import tqdm, trange from training.util import RunningAverage +import wandb + torch.backends.cudnn.benchmark = True np.random.seed(4711) @@ -30,6 +32,7 @@ class Trainer: epochs: int, val_metric: str = "accuracy", checkpoint_path: Optional[Path] = None, + use_wandb: Optional[bool] = False, ) -> None: """Initialization of the Trainer. @@ -38,6 +41,7 @@ class Trainer: epochs (int): Number of epochs to train. val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy". checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + use_wandb (Optional[bool]): Sync training to wandb. """ self.model = model @@ -48,13 +52,16 @@ class Trainer: if self.checkpoint_path is not None: self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + if use_wandb: + # TODO implement wandb logging. + pass + self.val_metric = val_metric self.best_val_metric = 0.0 logger.add(self.model.name + "_{time}.log") def train(self) -> None: """Training loop.""" - # TODO add summary # Set model to traning mode. self.model.train() @@ -93,10 +100,6 @@ class Trainer: # Perform updates using calculated gradients. self.model.optimizer.step() - # Update the learning rate scheduler. - if self.model.lr_scheduler is not None: - self.model.lr_scheduler.step() - # Compute metrics. loss_avg.update(loss.item()) output = output.data.cpu() @@ -174,8 +177,8 @@ class Trainer: return metrics_mean - def run(self) -> None: - """Training and evaluation loop.""" + def fit(self) -> None: + """Runs the training and evaluation loop.""" # Create new experiment. EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment = datetime.now().strftime("%m%d_%H%M%S") |