summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/callbacks')
-rw-r--r--src/training/trainer/callbacks/__init__.py15
-rw-r--r--src/training/trainer/callbacks/base.py78
-rw-r--r--src/training/trainer/callbacks/checkpoint.py95
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py52
-rw-r--r--src/training/trainer/callbacks/progress_bar.py19
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py32
6 files changed, 190 insertions, 101 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index 5942276..c81e4bf 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -1,7 +1,16 @@
"""The callback modules used in the training script."""
-from .base import Callback, CallbackList, Checkpoint
+from .base import Callback, CallbackList
+from .checkpoint import Checkpoint
from .early_stopping import EarlyStopping
-from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
+from .lr_schedulers import (
+ CosineAnnealingLR,
+ CyclicLR,
+ MultiStepLR,
+ OneCycleLR,
+ ReduceLROnPlateau,
+ StepLR,
+ SWA,
+)
from .progress_bar import ProgressBar
from .wandb_callbacks import WandbCallback, WandbImageLogger
@@ -9,6 +18,7 @@ __all__ = [
"Callback",
"CallbackList",
"Checkpoint",
+ "CosineAnnealingLR",
"EarlyStopping",
"WandbCallback",
"WandbImageLogger",
@@ -18,4 +28,5 @@ __all__ = [
"ProgressBar",
"ReduceLROnPlateau",
"StepLR",
+ "SWA",
]
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index 8df94f3..8c7b085 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -168,81 +168,3 @@ class CallbackList:
def __iter__(self) -> iter:
"""Iter function for callback list."""
return iter(self._callbacks)
-
-
-class Checkpoint(Callback):
- """Saving model parameters at the end of each epoch."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0
- ) -> None:
- """Monitors a quantity that will allow us to determine the best model weights.
-
- Args:
- monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
- mode (str): Description of parameter `mode`. Defaults to "auto".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
-
- """
- super().__init__()
- self.monitor = monitor
- self.mode = mode
- self.min_delta = torch.tensor(min_delta)
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Saves a checkpoint for the network parameters.
-
- Args:
- epoch (int): The current epoch.
- logs (Dict): The log containing the monitored metrics.
-
- """
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- is_best = True
- else:
- is_best = False
-
- self.model.save_checkpoint(is_best, epoch, self.monitor)
-
- def get_monitor_value(self, logs: Dict) -> Union[float, None]:
- """Extracts the monitored value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
- + f"metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return monitor_value
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..6fe06d3
--- /dev/null
+++ b/src/training/trainer/callbacks/checkpoint.py
@@ -0,0 +1,95 @@
+"""Callback checkpoint for training models."""
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class Checkpoint(Callback):
+ """Saving model parameters at the end of each epoch."""
+
+ mode_dict = {
+ "min": torch.lt,
+ "max": torch.gt,
+ }
+
+ def __init__(
+ self,
+ checkpoint_path: Path,
+ monitor: str = "accuracy",
+ mode: str = "auto",
+ min_delta: float = 0.0,
+ ) -> None:
+ """Monitors a quantity that will allow us to determine the best model weights.
+
+ Args:
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
+ monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+ mode (str): Description of parameter `mode`. Defaults to "auto".
+ min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+
+ """
+ super().__init__()
+ self.checkpoint_path = checkpoint_path
+ self.monitor = monitor
+ self.mode = mode
+ self.min_delta = torch.tensor(min_delta)
+
+ if mode not in ["auto", "min", "max"]:
+ logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
+
+ self.mode = "auto"
+
+ if self.mode == "auto":
+ if "accuracy" in self.monitor:
+ self.mode = "max"
+ else:
+ self.mode = "min"
+ logger.debug(
+ f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+ )
+
+ torch_inf = torch.tensor(np.inf)
+ self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+ self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+ @property
+ def monitor_op(self) -> float:
+ """Returns the comparison method."""
+ return self.mode_dict[self.mode]
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Saves a checkpoint for the network parameters.
+
+ Args:
+ epoch (int): The current epoch.
+ logs (Dict): The log containing the monitored metrics.
+
+ """
+ current = self.get_monitor_value(logs)
+ if current is None:
+ return
+ if self.monitor_op(current - self.min_delta, self.best_score):
+ self.best_score = current
+ is_best = True
+ else:
+ is_best = False
+
+ self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
+
+ def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+ """Extracts the monitored value."""
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logger.warning(
+ f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
+ + f" metrics are: {','.join(list(logs.keys()))}"
+ )
+ return None
+ return monitor_value
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index ba2226a..bb41d2d 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -1,6 +1,7 @@
"""Callbacks for learning rate schedulers."""
from typing import Callable, Dict, List, Optional, Type
+from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback
from text_recognizer.models import Model
@@ -95,3 +96,54 @@ class OneCycleLR(Callback):
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
+
+
+class CosineAnnealingLR(Callback):
+ """Callback for Cosine Annealing."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every epoch."""
+ self.lr_scheduler.step()
+
+
+class SWA(Callback):
+ """Stochastic Weight Averaging callback."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.swa_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.swa_start = self.model.swa_start
+ self.swa_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if epoch > self.swa_start:
+ self.model.swa_network.update_parameters(self.model.network)
+ self.swa_scheduler.step()
+ else:
+ self.lr_scheduler.step()
+
+ def on_fit_end(self) -> None:
+ """Update batch norm statistics for the swa model at the end of training."""
+ if self.model.swa_network:
+ update_bn(
+ self.model.val_dataloader(),
+ self.model.swa_network,
+ device=self.model.device,
+ )
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
index 1970747..7829fa0 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -18,11 +18,11 @@ class ProgressBar(Callback):
def _configure_progress_bar(self) -> None:
"""Configures the tqdm progress bar with custom bar format."""
self.progress_bar = tqdm(
- total=len(self.model.data_loaders["train"]),
- leave=True,
- unit="step",
+ total=len(self.model.train_dataloader()),
+ leave=False,
+ unit="steps",
mininterval=self.log_batch_frequency,
- bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
)
def _key_abbreviations(self, logs: Dict) -> Dict:
@@ -34,13 +34,16 @@ class ProgressBar(Callback):
return {rename(key): value for key, value in logs.items()}
- def on_fit_begin(self) -> None:
- """Creates a tqdm progress bar."""
- self._configure_progress_bar()
+ # def on_fit_begin(self) -> None:
+ # """Creates a tqdm progress bar."""
+ # self._configure_progress_bar()
def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
"""Updates the description with the current epoch."""
- self.progress_bar.reset()
+ if epoch == 1:
+ self._configure_progress_bar()
+ else:
+ self.progress_bar.reset()
self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index e44c745..6643a44 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -2,7 +2,8 @@
from typing import Callable, Dict, List, Optional, Type
import numpy as np
-from torchvision.transforms import Compose, ToTensor
+import torch
+from torchvision.transforms import ToTensor
from training.trainer.callbacks import Callback
import wandb
@@ -50,43 +51,48 @@ class WandbImageLogger(Callback):
self,
example_indices: Optional[List] = None,
num_examples: int = 4,
- transfroms: Optional[Callable] = None,
+ use_transpose: Optional[bool] = False,
) -> None:
"""Initializes the WandbImageLogger with the model to train.
Args:
example_indices (Optional[List]): Indices for validation images. Defaults to None.
num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
- transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to
- None.
+ use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False.
"""
super().__init__()
self.example_indices = example_indices
self.num_examples = num_examples
- self.transfroms = transfroms
- if self.transfroms is None:
- self.transforms = Compose([Transpose()])
+ self.transpose = Transpose() if use_transpose else None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
self.model = model
- data_loader = self.model.data_loaders["val"]
if self.example_indices is None:
self.example_indices = np.random.randint(
- 0, len(data_loader.dataset.data), self.num_examples
+ 0, len(self.model.val_dataset), self.num_examples
)
- self.val_images = data_loader.dataset.data[self.example_indices]
- self.val_targets = data_loader.dataset.targets[self.example_indices].numpy()
+ self.val_images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.val_targets = self.val_targets.tolist()
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
"""Get network predictions on validation images."""
images = []
for i, image in enumerate(self.val_images):
- image = self.transforms(image)
+ image = self.transpose(image) if self.transpose is not None else image
pred, conf = self.model.predict_on_image(image)
- ground_truth = self.model.mapper(int(self.val_targets[i]))
+ if isinstance(self.val_targets[i], list):
+ ground_truth = "".join(
+ [
+ self.model.mapper(int(target_index))
+ for target_index in self.val_targets[i]
+ ]
+ ).rstrip("_")
+ else:
+ ground_truth = self.val_targets[i]
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))