author     aktersnurra <gustaf.rydholm@gmail.com>  2020-07-22 23:18:08 +0200
committer  aktersnurra <gustaf.rydholm@gmail.com>  2020-07-22 23:18:08 +0200
commit     f473456c19558aaf8552df97a51d4e18cc69dfa8 (patch)
tree       0d35ce2410ff623ba5fb433d616d95b67ecf7a98 /src/training/train.py
parent     ad3bd52530f4800d4fb05dfef3354921f95513af (diff)
Working training loop and testing of trained CharacterModel.
Diffstat (limited to 'src/training/train.py')
-rw-r--r--  src/training/train.py  103
1 file changed, 46 insertions, 57 deletions
diff --git a/src/training/train.py b/src/training/train.py
index 4a452b6..8cd5110 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -1,8 +1,8 @@
 """Training script for PyTorch models."""

-from datetime import datetime
 from pathlib import Path
-from typing import Callable, Dict, Optional
+import time
+from typing import Dict, Optional, Type

 from loguru import logger
 import numpy as np
@@ -11,6 +11,7 @@ from tqdm import tqdm, trange
 from training.util import RunningAverage
 import wandb

+from text_recognizer.models import Model

 torch.backends.cudnn.benchmark = True
 np.random.seed(4711)
@@ -18,17 +19,16 @@ torch.manual_seed(4711)
 torch.cuda.manual_seed(4711)


-EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-
-
 class Trainer:
     """Trainer for training PyTorch models."""

     # TODO implement wandb.
+    # TODO implement Bayesian parameter search.

     def __init__(
         self,
-        model: Callable,
+        model: Type[Model],
+        model_dir: Path,
         epochs: int,
         val_metric: str = "accuracy",
         checkpoint_path: Optional[Path] = None,
@@ -37,7 +37,8 @@ class Trainer:
         """Initialization of the Trainer.

         Args:
-            model (Callable): A model object.
+            model (Type[Model]): A model object.
+            model_dir (Path): Path to the model directory.
             epochs (int): Number of epochs to train.
             val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy".
             checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
@@ -45,6 +46,7 @@ class Trainer:

         """
         self.model = model
+        self.model_dir = model_dir
         self.epochs = epochs
         self.checkpoint_path = checkpoint_path
         self.start_epoch = 0
@@ -58,7 +60,10 @@ class Trainer:

         self.val_metric = val_metric
         self.best_val_metric = 0.0
-        logger.add(self.model.name + "_{time}.log")
+
+        # Parse the name of the experiment.
+        experiment_dir = str(self.model_dir.parents[1]).split("/")
+        self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1]

     def train(self) -> None:
         """Training loop."""
@@ -68,13 +73,13 @@ class Trainer:
         # Running average for the loss.
         loss_avg = RunningAverage()

-        data_loader = self.model.data_loaders["train"]
+        data_loader = self.model.data_loaders("train")

         with tqdm(
             total=len(data_loader),
             leave=False,
             unit="step",
-            bar_format="{n_fmt}/{total_fmt} {bar} {remaining} {rate_inv_fmt}{postfix}",
+            bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}",
         ) as t:
             for data, targets in data_loader:

@@ -85,7 +90,7 @@ class Trainer:

                 # Forward pass.
                 # Get the network prediction.
-                output = self.model.predict(data)
+                output = self.model.network(data)

                 # Compute the loss.
                 loss = self.model.criterion(output, targets)
@@ -105,16 +110,20 @@ class Trainer:
                 output = output.data.cpu()
                 targets = targets.data.cpu()
                 metrics = {
-                    metric: round(self.model.metrics[metric](output, targets), 4)
+                    metric: self.model.metrics[metric](output, targets)
                     for metric in self.model.metrics
                 }
-                metrics["loss"] = round(loss_avg(), 4)
+                metrics["loss"] = loss_avg()

                 # Update Tqdm progress bar.
                 t.set_postfix(**metrics)
                 t.update()

-    def evaluate(self) -> Dict:
+        # If the model has a learning rate scheduler, compute a step.
+        if self.model.lr_scheduler is not None:
+            self.model.lr_scheduler.step()
+
+    def validate(self) -> Dict:
         """Evaluation loop.

         Returns:
@@ -125,7 +134,7 @@ class Trainer:
         self.model.eval()

         # Running average for the loss.
-        data_loader = self.model.data_loaders["val"]
+        data_loader = self.model.data_loaders("val")

         # Running average for the loss.
         loss_avg = RunningAverage()
@@ -137,7 +146,7 @@ class Trainer:
             total=len(data_loader),
             leave=False,
             unit="step",
-            bar_format="{n_fmt}/{total_fmt} {bar} {remaining} {rate_inv_fmt}{postfix}",
+            bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}",
         ) as t:
             for data, targets in data_loader:
                 data, targets = (
@@ -145,22 +154,23 @@ class Trainer:
                     targets.to(self.model.device),
                 )

-                # Forward pass.
-                # Get the network prediction.
-                output = self.model.predict(data)
+                with torch.no_grad():
+                    # Forward pass.
+                    # Get the network prediction.
+                    output = self.model.network(data)

-                # Compute the loss.
-                loss = self.model.criterion(output, targets)
+                    # Compute the loss.
+                    loss = self.model.criterion(output, targets)

                 # Compute metrics.
                 loss_avg.update(loss.item())
                 output = output.data.cpu()
                 targets = targets.data.cpu()
                 metrics = {
-                    metric: round(self.model.metrics[metric](output, targets), 4)
+                    metric: self.model.metrics[metric](output, targets)
                     for metric in self.model.metrics
                 }
-                metrics["loss"] = round(loss.item(), 4)
+                metrics["loss"] = loss.item()

                 summary.append(metrics)

@@ -170,7 +180,7 @@ class Trainer:

         # Compute mean of all metrics.
         metrics_mean = {
-            metric: np.mean(x[metric] for x in summary) for metric in summary[0]
+            metric: np.mean([x[metric] for x in summary]) for metric in summary[0]
         }
         metrics_str = " - ".join(f"{k}: {v}" for k, v in metrics_mean.items())
         logger.debug(metrics_str)
@@ -179,55 +189,34 @@ class Trainer:

     def fit(self) -> None:
         """Runs the training and evaluation loop."""
-        # Create new experiment.
-        EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
-        experiment = datetime.now().strftime("%m%d_%H%M%S")
-        experiment_dir = EXPERIMENTS_DIRNAME / self.model.network.__name__ / experiment
-
-        # Create log and model directories.
-        log_dir = experiment_dir / "log"
-        model_dir = experiment_dir / "model"
-
-        # Make sure the log directory exists.
-        log_dir.mkdir(parents=True, exist_ok=True)
-
-        logger.add(
-            str(log_dir / "train.log"),
-            format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
-        )
-
-        logger.debug(
-            f"Running an experiment called {self.model.network.__name__}/{experiment}."
-        )
-
-        # Pŕints a summary of the network in terminal.
-        self.model.summary()
+        logger.debug(f"Running an experiment called {self.experiment_name}.")
+        t_start = time.time()

         # Run the training loop.
         for epoch in trange(
-            total=self.epochs,
+            self.epochs,
             initial=self.start_epoch,
-            leave=True,
-            bar_format="{desc}: {n_fmt}/{total_fmt} {bar} {remaining}{postfix}",
+            leave=False,
+            bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:10}| {remaining}{postfix}",
             desc="Epoch",
         ):
             # Perform one training pass over the training set.
             self.train()

             # Evaluate the model on the validation set.
-            val_metrics = self.evaluate()
-
-            # If the model has a learning rate scheduler, compute a step.
-            if self.model.lr_scheduler is not None:
-                self.model.lr_scheduler.step()
+            val_metrics = self.validate()

             # The validation metric to evaluate the model on, e.g. accuracy.
             val_metric = val_metrics[self.val_metric]
             is_best = val_metric >= self.best_val_metric
-
+            self.best_val_metric = val_metric if is_best else self.best_val_metric
             # Save checkpoint.
-            self.model.save_checkpoint(model_dir, is_best, epoch, self.val_metric)
+            self.model.save_checkpoint(self.model_dir, is_best, epoch, self.val_metric)

             if self.start_epoch > 0 and epoch + self.start_epoch == self.epochs:
                 logger.debug(f"Trained the model for {self.epochs} number of epochs.")
                 break
+
+        t_end = time.time()
+        t_training = t_end - t_start
+        logger.info(f"Training took {t_training:.2f} s.")
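
Beyond renaming evaluate() to validate(), the main behavioral changes in this commit are that validation now runs under torch.no_grad(), the learning rate scheduler steps once per epoch inside train(), checkpoints are written to a caller-supplied model_dir instead of a directory created inside fit(), and the total training time is logged. A minimal sketch of how the refactored Trainer might be driven from a run script follows; the CharacterModel import, its constructor call, and the experiment directory layout are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (not part of this commit): constructing and
# running the refactored Trainer from a calling script.
from pathlib import Path
import time

from text_recognizer.models import CharacterModel  # assumed model class
from training.train import Trainer

# Assumed layout for illustration only; the exact directory level that
# Trainer expects for model_dir depends on how the calling script lays out
# its experiment directories (see the experiment-name parsing in
# Trainer.__init__).
model_dir = Path("experiments") / "CharacterNetwork" / time.strftime("%m%d_%H%M%S") / "model"
model_dir.mkdir(parents=True, exist_ok=True)

model = CharacterModel()  # assumed constructor; real arguments may differ
trainer = Trainer(model=model, model_dir=model_dir, epochs=10, val_metric="accuracy")
trainer.fit()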