summary | refs | log | tree | commit | diff
path: root/src/training/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/train.py')
-rw-r--r-- src/training/train.py 103
1 files changed, 46 insertions, 57 deletions
diff --git a/src/training/train.py b/src/training/train.py
index 4a452b6..8cd5110 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -1,8 +1,8 @@
"""Training script for PyTorch models."""
-from datetime import datetime
from pathlib import Path
-from typing import Callable, Dict, Optional
+import time
+from typing import Dict, Optional, Type
from loguru import logger
import numpy as np
@@ -11,6 +11,7 @@ from tqdm import tqdm, trange
from training.util import RunningAverage
import wandb
+from text_recognizer.models import Model
torch.backends.cudnn.benchmark = True
np.random.seed(4711)
@@ -18,17 +19,16 @@ torch.manual_seed(4711)
torch.cuda.manual_seed(4711)
-EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-
-
class Trainer:
"""Trainer for training PyTorch models."""
# TODO implement wandb.
+ # TODO implement Bayesian parameter search.
def __init__(
self,
- model: Callable,
+ model: Type[Model],
+ model_dir: Path,
epochs: int,
val_metric: str = "accuracy",
checkpoint_path: Optional[Path] = None,
@@ -37,7 +37,8 @@ class Trainer:
"""Initialization of the Trainer.
Args:
- model (Callable): A model object.
+ model (Type[Model]): A model object.
+ model_dir (Path): Path to the model directory.
epochs (int): Number of epochs to train.
val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy".
checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
@@ -45,6 +46,7 @@ class Trainer:
"""
self.model = model
+ self.model_dir = model_dir
self.epochs = epochs
self.checkpoint_path = checkpoint_path
self.start_epoch = 0
@@ -58,7 +60,10 @@ class Trainer:
self.val_metric = val_metric
self.best_val_metric = 0.0
- logger.add(self.model.name + "_{time}.log")
+
+ # Parse the name of the experiment.
+ experiment_dir = str(self.model_dir.parents[1]).split("/")
+ self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1]
def train(self) -> None:
"""Training loop."""
@@ -68,13 +73,13 @@ class Trainer:
# Running average for the loss.
loss_avg = RunningAverage()
- data_loader = self.model.data_loaders["train"]
+ data_loader = self.model.data_loaders("train")
with tqdm(
total=len(data_loader),
leave=False,
unit="step",
- bar_format="{n_fmt}/{total_fmt} {bar} {remaining} {rate_inv_fmt}{postfix}",
+ bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}",
) as t:
for data, targets in data_loader:
@@ -85,7 +90,7 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.predict(data)
+ output = self.model.network(data)
# Compute the loss.
loss = self.model.criterion(output, targets)
@@ -105,16 +110,20 @@ class Trainer:
output = output.data.cpu()
targets = targets.data.cpu()
metrics = {
- metric: round(self.model.metrics[metric](output, targets), 4)
+ metric: self.model.metrics[metric](output, targets)
for metric in self.model.metrics
}
- metrics["loss"] = round(loss_avg(), 4)
+ metrics["loss"] = loss_avg()
# Update Tqdm progress bar.
t.set_postfix(**metrics)
t.update()
- def evaluate(self) -> Dict:
+ # If the model has a learning rate scheduler, compute a step.
+ if self.model.lr_scheduler is not None:
+ self.model.lr_scheduler.step()
+
+ def validate(self) -> Dict:
"""Evaluation loop.
Returns:
@@ -125,7 +134,7 @@ class Trainer:
self.model.eval()
# Data loader for the validation set.
- data_loader = self.model.data_loaders["val"]
+ data_loader = self.model.data_loaders("val")
# Running average for the loss.
loss_avg = RunningAverage()
@@ -137,7 +146,7 @@ class Trainer:
total=len(data_loader),
leave=False,
unit="step",
- bar_format="{n_fmt}/{total_fmt} {bar} {remaining} {rate_inv_fmt}{postfix}",
+ bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}",
) as t:
for data, targets in data_loader:
data, targets = (
@@ -145,22 +154,23 @@ class Trainer:
targets.to(self.model.device),
)
- # Forward pass.
- # Get the network prediction.
- output = self.model.predict(data)
+ with torch.no_grad():
+ # Forward pass.
+ # Get the network prediction.
+ output = self.model.network(data)
- # Compute the loss.
- loss = self.model.criterion(output, targets)
+ # Compute the loss.
+ loss = self.model.criterion(output, targets)
# Compute metrics.
loss_avg.update(loss.item())
output = output.data.cpu()
targets = targets.data.cpu()
metrics = {
- metric: round(self.model.metrics[metric](output, targets), 4)
+ metric: self.model.metrics[metric](output, targets)
for metric in self.model.metrics
}
- metrics["loss"] = round(loss.item(), 4)
+ metrics["loss"] = loss.item()
summary.append(metrics)
@@ -170,7 +180,7 @@ class Trainer:
# Compute mean of all metrics.
metrics_mean = {
- metric: np.mean(x[metric] for x in summary) for metric in summary[0]
+ metric: np.mean([x[metric] for x in summary]) for metric in summary[0]
}
metrics_str = " - ".join(f"{k}: {v}" for k, v in metrics_mean.items())
logger.debug(metrics_str)
@@ -179,55 +189,34 @@ class Trainer:
def fit(self) -> None:
"""Runs the training and evaluation loop."""
- # Create new experiment.
- EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
- experiment = datetime.now().strftime("%m%d_%H%M%S")
- experiment_dir = EXPERIMENTS_DIRNAME / self.model.network.__name__ / experiment
-
- # Create log and model directories.
- log_dir = experiment_dir / "log"
- model_dir = experiment_dir / "model"
-
- # Make sure the log directory exists.
- log_dir.mkdir(parents=True, exist_ok=True)
-
- logger.add(
- str(log_dir / "train.log"),
- format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
- )
-
- logger.debug(
- f"Running an experiment called {self.model.network.__name__}/{experiment}."
- )
-
- # Prints a summary of the network in terminal.
- self.model.summary()
+ logger.debug(f"Running an experiment called {self.experiment_name}.")
+ t_start = time.time()
# Run the training loop.
for epoch in trange(
- total=self.epochs,
+ self.epochs,
initial=self.start_epoch,
- leave=True,
- bar_format="{desc}: {n_fmt}/{total_fmt} {bar} {remaining}{postfix}",
+ leave=False,
+ bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:10}| {remaining}{postfix}",
desc="Epoch",
):
# Perform one training pass over the training set.
self.train()
# Evaluate the model on the validation set.
- val_metrics = self.evaluate()
-
- # If the model has a learning rate scheduler, compute a step.
- if self.model.lr_scheduler is not None:
- self.model.lr_scheduler.step()
+ val_metrics = self.validate()
# The validation metric to evaluate the model on, e.g. accuracy.
val_metric = val_metrics[self.val_metric]
is_best = val_metric >= self.best_val_metric
-
+ self.best_val_metric = val_metric if is_best else self.best_val_metric
# Save checkpoint.
- self.model.save_checkpoint(model_dir, is_best, epoch, self.val_metric)
+ self.model.save_checkpoint(self.model_dir, is_best, epoch, self.val_metric)
if self.start_epoch > 0 and epoch + self.start_epoch == self.epochs:
logger.debug(f"Trained the model for {self.epochs} number of epochs.")
break
+
+ t_end = time.time()
+ t_training = t_end - t_start
+ logger.info(f"Training took {t_training:.2f} s.")