diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-08 22:26:15 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-08 22:26:15 +0200 |
commit | 72befa6d8cc4c7ecf698512a97424641ee81725a (patch) | |
tree | af2ad5662eb8385c51c33c00761bd6f6bf29d8ee /training | |
parent | 544d37d9c3b8b57798d57988d6f000c19ab6b074 (diff) |
Move callbacks to training folder, refactor
Diffstat (limited to 'training')
-rw-r--r-- | training/callbacks/__init__.py | 1 | ||||
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 211 | ||||
-rw-r--r-- | training/conf/callbacks/wandb.yaml | 8 | ||||
-rw-r--r-- | training/run.py | 6 | ||||
-rw-r--r-- | training/utils.py | 2 |
5 files changed, 220 insertions, 8 deletions
diff --git a/training/callbacks/__init__.py b/training/callbacks/__init__.py new file mode 100644 index 0000000..82d8ce3 --- /dev/null +++ b/training/callbacks/__init__.py @@ -0,0 +1 @@ +"""Module for PyTorch Lightning callbacks.""" diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py new file mode 100644 index 0000000..d9d81f6 --- /dev/null +++ b/training/callbacks/wandb_callbacks.py @@ -0,0 +1,211 @@ +"""Weights and Biases callbacks.""" +from pathlib import Path +from typing import List + +import attr +import wandb +from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning.loggers import LoggerCollection, WandbLogger + + +def get_wandb_logger(trainer: Trainer) -> WandbLogger: + """Safely get W&B logger from Trainer.""" + + if isinstance(trainer.logger, WandbLogger): + return trainer.logger + + if isinstance(trainer.logger, LoggerCollection): + for logger in trainer.logger: + if isinstance(logger, WandbLogger): + return logger + + raise Exception("Weight and Biases logger not found for some reason...") + + +@attr.s +class WatchModel(Callback): + """Make W&B watch the model at the beginning of the run.""" + + log: str = attr.ib(default="gradients") + log_freq: int = attr.ib(default=100) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Watches model weights with wandb.""" + logger = get_wandb_logger(trainer) + logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + + +@attr.s +class UploadCodeAsArtifact(Callback): + """Upload all *.py files to W&B as an artifact, at the beginning of the run.""" + + project_dir: Path = attr.ib(converter=Path) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Uploads project code as an artifact.""" + logger = get_wandb_logger(trainer) + experiment = logger.experiment + artifact = wandb.Artifact("project-source", type="code") + for filepath in self.project_dir.glob("**/*.py"): + artifact.add_file(filepath) + + experiment.use_artifact(artifact) + + +@attr.s +class UploadCheckpointAsArtifact(Callback): + """Upload checkpoint to wandb as an artifact, at the end of a run.""" + + ckpt_dir: Path = attr.ib(converter=Path) + upload_best_only: bool = attr.ib() + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Uploads model checkpoint to W&B.""" + logger = get_wandb_logger(trainer) + experiment = logger.experiment + ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") + + if self.upload_best_only: + ckpts.add_file(trainer.checkpoint_callback.best_model_path) + else: + for ckpt in (self.ckpt_dir).glob("**/*.ckpt"): + ckpts.add_file(ckpt) + + experiment.use_artifact(ckpts) + + +@attr.s +class LogTextPredictions(Callback): + """Logs a validation batch with image to text transcription.""" + + num_samples: int = attr.ib(default=8) + ready: bool = attr.ib(default=True) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def _log_predictions( + stage: str, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs the predicted text contained in the images.""" + if not self.ready: + return None + + logger = get_wandb_logger(trainer) + experiment = logger.experiment + + # Get a validation batch from the validation dataloader. + samples = next(iter(trainer.datamodule.val_dataloader())) + imgs, labels = samples + + imgs = imgs.to(device=pl_module.device) + logits = pl_module(imgs) + + mapping = pl_module.mapping + experiment.log( + { + f"OCR/{experiment.name}/{stage}": [ + wandb.Image( + img, + caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", + ) + for img, pred, label in zip( + imgs[: self.num_samples], + logits[: self.num_samples], + labels[: self.num_samples], + ) + ] + } + ) + + def on_sanity_check_start( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Sets ready attribute.""" + self.ready = False + + def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs predictions on validation epoch end.""" + self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) + + +@attr.s +class LogReconstuctedImages(Callback): + """Log reconstructions of images.""" + + num_samples: int = attr.ib(default=8) + ready: bool = attr.ib(default=True) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def _log_reconstruction( + self, stage: str, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs the reconstructions.""" + if not self.ready: + return None + + logger = get_wandb_logger(trainer) + experiment = logger.experiment + + # Get a validation batch from the validation dataloader. + samples = next(iter(trainer.datamodule.val_dataloader())) + imgs, _ = samples + + imgs = imgs.to(device=pl_module.device) + reconstructions = pl_module(imgs) + + experiment.log( + { + f"Reconstructions/{experiment.name}/{stage}": [ + [ + wandb.Image(img), + wandb.Image(rec), + ] + for img, rec in zip( + imgs[: self.num_samples], + reconstructions[: self.num_samples], + ) + ] + } + ) + + def on_sanity_check_start( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Sets ready attribute.""" + self.ready = False + + def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs predictions on validation epoch end.""" + self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml index 2d56bfa..6eedb71 100644 --- a/training/conf/callbacks/wandb.yaml +++ b/training/conf/callbacks/wandb.yaml @@ -2,19 +2,19 @@ defaults: - default.yaml watch_model: - _target_: text_recognizer.callbacks.wandb_callbacks.WatchModel + _target_: callbacks.wandb_callbacks.WatchModel log: "all" log_freq: 100 upload_code_as_artifact: - _target_: text_recognizer.callbacks.wandb_callbacks.UploadCodeAsArtifact + _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact project_dir: ${work_dir}/text_recognizer upload_ckpts_as_artifact: - _target_: text_recognizer.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact + _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact ckpt_dir: "checkpoints/" upload_best_only: True log_text_predictions: - _target_: text_recognizer.callbacks.wandb_callbacks.LogTextPredictions + _target_: callbacks.wandb_callbacks.LogTextPredictions num_samples: 8 diff --git a/training/run.py b/training/run.py index 31da666..695a298 100644 --- a/training/run.py +++ b/training/run.py @@ -2,14 +2,14 @@ from typing import List, Optional, Type import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import ( Callback, LightningDataModule, LightningModule, - Trainer, seed_everything, + Trainer, ) from pytorch_lightning.loggers import LightningLoggerBase from torch import nn @@ -67,7 +67,7 @@ def run(config: DictConfig) -> Optional[float]: log.info("Training network...") trainer.fit(model, datamodule=datamodule) - if config.test: + if config.test:lua/cfg/themes/dark.lua log.info("Testing network...") trainer.test(model, datamodule=datamodule) diff --git a/training/utils.py b/training/utils.py index 140d97e..88b72b7 100644 --- a/training/utils.py +++ b/training/utils.py @@ -3,8 +3,8 @@ from typing import Any, List, Type import warnings import hydra +from loguru import logger as log from omegaconf import DictConfig, OmegaConf -import loguru.logger as log from pytorch_lightning import ( Callback, LightningModule, |