diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 |
commit | 5a78fc2e33c28968a69d033cb10d638f4f63fed1 (patch) | |
tree | e11cea8366c848e5500f85968ee5369ff8d96b00 /src/training/train.py | |
parent | 7c4de6d88664d2ea1b084f316a11896dde3e1150 (diff) |
Working on getting experiment loop.
Diffstat (limited to 'src/training/train.py')
-rw-r--r-- | src/training/train.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/src/training/train.py b/src/training/train.py index 783de02..4a452b6 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -9,6 +9,8 @@ import numpy as np import torch from tqdm import tqdm, trange from training.util import RunningAverage +import wandb + torch.backends.cudnn.benchmark = True np.random.seed(4711) @@ -30,6 +32,7 @@ class Trainer: epochs: int, val_metric: str = "accuracy", checkpoint_path: Optional[Path] = None, + use_wandb: Optional[bool] = False, ) -> None: """Initialization of the Trainer. @@ -38,6 +41,7 @@ class Trainer: epochs (int): Number of epochs to train. val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy". checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + use_wandb (Optional[bool]): Sync training to wandb. """ self.model = model @@ -48,13 +52,16 @@ class Trainer: if self.checkpoint_path is not None: self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + if use_wandb: + # TODO implement wandb logging. + pass + self.val_metric = val_metric self.best_val_metric = 0.0 logger.add(self.model.name + "_{time}.log") def train(self) -> None: """Training loop.""" - # TODO add summary # Set model to traning mode. self.model.train() @@ -93,10 +100,6 @@ class Trainer: # Perform updates using calculated gradients. self.model.optimizer.step() - # Update the learning rate scheduler. - if self.model.lr_scheduler is not None: - self.model.lr_scheduler.step() - # Compute metrics. loss_avg.update(loss.item()) output = output.data.cpu() @@ -174,8 +177,8 @@ class Trainer: return metrics_mean - def run(self) -> None: - """Training and evaluation loop.""" + def fit(self) -> None: + """Runs the training and evaluation loop.""" # Create new experiment. EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment = datetime.now().strftime("%m%d_%H%M%S") |