summaryrefslogtreecommitdiff
path: root/training/callbacks
diff options
context:
space:
mode:
Diffstat (limited to 'training/callbacks')
-rw-r--r--training/callbacks/wandb_callbacks.py83
1 files changed, 36 insertions, 47 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 6379cc0..906531f 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -1,11 +1,10 @@
"""Weights and Biases callbacks."""
from pathlib import Path
-from typing import List
-import attr
import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+from pytorch_lightning.utilities import rank_zero_only
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -22,31 +21,27 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:
raise Exception("Weight and Biases logger not found for some reason...")
-@attr.s
class WatchModel(Callback):
"""Make W&B watch the model at the beginning of the run."""
- log: str = attr.ib(default="gradients")
- log_freq: int = attr.ib(default=100)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, log: str = "gradients", log_freq: int = 100) -> None:
+ self.log = log
+ self.log_freq = log_freq
+ @rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Watches model weights with wandb."""
logger = get_wandb_logger(trainer)
logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
-@attr.s
class UploadCodeAsArtifact(Callback):
"""Upload all *.py files to W&B as an artifact, at the beginning of the run."""
- project_dir: Path = attr.ib(converter=Path)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, project_dir: str) -> None:
+ self.project_dir = Path(project_dir)
+ @rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads project code as an artifact."""
logger = get_wandb_logger(trainer)
@@ -58,16 +53,16 @@ class UploadCodeAsArtifact(Callback):
experiment.use_artifact(artifact)
-@attr.s
-class UploadCheckpointAsArtifact(Callback):
+class UploadCheckpointsAsArtifact(Callback):
"""Upload checkpoint to wandb as an artifact, at the end of a run."""
- ckpt_dir: Path = attr.ib(converter=Path)
- upload_best_only: bool = attr.ib()
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(
+ self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
+ ) -> None:
+ self.ckpt_dir = ckpt_dir
+ self.upload_best_only = upload_best_only
+ @rank_zero_only
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads model checkpoint to W&B."""
logger = get_wandb_logger(trainer)
@@ -83,15 +78,12 @@ class UploadCheckpointAsArtifact(Callback):
experiment.use_artifact(ckpts)
-@attr.s
class LogTextPredictions(Callback):
"""Logs a validation batch with image to text transcription."""
- num_samples: int = attr.ib(default=8)
- ready: bool = attr.ib(default=True)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, num_samples: int = 8) -> None:
+ self.num_samples = num_samples
+ self.ready = False
def _log_predictions(
self, stage: str, trainer: Trainer, pl_module: LightningModule
@@ -111,20 +103,20 @@ class LogTextPredictions(Callback):
logits = pl_module(imgs)
mapping = pl_module.mapping
+ columns = ["id", "image", "prediction", "truth"]
+ data = [
+ [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+ for id, (img, pred, label) in enumerate(
+ zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
+ )
+ )
+ ]
+
experiment.log(
- {
- f"OCR/{experiment.name}/{stage}": [
- wandb.Image(
- img,
- caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
- )
- for img, pred, label in zip(
- imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
- )
- ]
- }
+ {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
)
def on_sanity_check_start(
@@ -143,20 +135,17 @@ class LogTextPredictions(Callback):
"""Logs predictions on validation epoch end."""
self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
-@attr.s
class LogReconstuctedImages(Callback):
"""Log reconstructions of images."""
- num_samples: int = attr.ib(default=8)
- ready: bool = attr.ib(default=True)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+ def __init__(self, num_samples: int = 8) -> None:
+ self.num_samples = num_samples
+ self.ready = False
def _log_reconstruction(
self, stage: str, trainer: Trainer, pl_module: LightningModule
@@ -202,6 +191,6 @@ class LogReconstuctedImages(Callback):
"""Logs predictions on validation epoch end."""
self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)