diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 18:18:48 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 18:18:48 +0200 |
commit | bd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch) | |
tree | e55cb3744904f7c2a0348b100c7e92a65e538a16 /training/callbacks | |
parent | 75801019981492eedf9280cb352eea3d8e99b65f (diff) |
Training working, multiple bug fixes
Diffstat (limited to 'training/callbacks')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 83 |
1 files changed, 36 insertions, 47 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 6379cc0..906531f 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -1,11 +1,10 @@ """Weights and Biases callbacks.""" from pathlib import Path -from typing import List -import attr import wandb from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger +from pytorch_lightning.utilities import rank_zero_only def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -22,31 +21,27 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: raise Exception("Weight and Biases logger not found for some reason...") -@attr.s class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" - log: str = attr.ib(default="gradients") - log_freq: int = attr.ib(default=100) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, log: str = "gradients", log_freq: int = 100) -> None: + self.log = log + self.log_freq = log_freq + @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Watches model weights with wandb.""" logger = get_wandb_logger(trainer) logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) -@attr.s class UploadCodeAsArtifact(Callback): """Upload all *.py files to W&B as an artifact, at the beginning of the run.""" - project_dir: Path = attr.ib(converter=Path) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, project_dir: str) -> None: + self.project_dir = Path(project_dir) + @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Uploads project code as an artifact.""" logger = get_wandb_logger(trainer) @@ -58,16 +53,16 @@ class UploadCodeAsArtifact(Callback): experiment.use_artifact(artifact) -@attr.s -class UploadCheckpointAsArtifact(Callback): +class UploadCheckpointsAsArtifact(Callback): """Upload checkpoint to wandb as an artifact, at the end of a run.""" - ckpt_dir: Path = attr.ib(converter=Path) - upload_best_only: bool = attr.ib() - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__( + self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False + ) -> None: + self.ckpt_dir = ckpt_dir + self.upload_best_only = upload_best_only + @rank_zero_only def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Uploads model checkpoint to W&B.""" logger = get_wandb_logger(trainer) @@ -83,15 +78,12 @@ class UploadCheckpointAsArtifact(Callback): experiment.use_artifact(ckpts) -@attr.s class LogTextPredictions(Callback): """Logs a validation batch with image to text transcription.""" - num_samples: int = attr.ib(default=8) - ready: bool = attr.ib(default=True) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, num_samples: int = 8) -> None: + self.num_samples = num_samples + self.ready = False def _log_predictions( self, stage: str, trainer: Trainer, pl_module: LightningModule @@ -111,20 +103,20 @@ class LogTextPredictions(Callback): logits = pl_module(imgs) mapping = pl_module.mapping + columns = ["id", "image", "prediction", "truth"] + data = [ + [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)] + for id, (img, pred, label) in enumerate( + zip( + imgs[: self.num_samples], + logits[: self.num_samples], + labels[: self.num_samples], + ) + ) + ] + experiment.log( - { - f"OCR/{experiment.name}/{stage}": [ - wandb.Image( - img, - caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", - ) - for img, pred, label in zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) - ] - } + {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)} ) def on_sanity_check_start( @@ -143,20 +135,17 @@ class LogTextPredictions(Callback): """Logs predictions on validation epoch end.""" self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) -@attr.s class LogReconstuctedImages(Callback): """Log reconstructions of images.""" - num_samples: int = attr.ib(default=8) - ready: bool = attr.ib(default=True) - - def __attrs_pre_init__(self) -> None: - super().__init__() + def __init__(self, num_samples: int = 8) -> None: + self.num_samples = num_samples + self.ready = False def _log_reconstruction( self, stage: str, trainer: Trainer, pl_module: LightningModule @@ -202,6 +191,6 @@ class LogReconstuctedImages(Callback): """Logs predictions on validation epoch end.""" self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) |