summaryrefslogtreecommitdiff
path: root/text_recognizer/callbacks/wandb_callbacks.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/callbacks/wandb_callbacks.py')
-rw-r--r--text_recognizer/callbacks/wandb_callbacks.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
index 3936aaf..900c3b1 100644
--- a/text_recognizer/callbacks/wandb_callbacks.py
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -29,6 +29,9 @@ class WatchModel(Callback):
log: str = attr.ib(default="gradients")
log_freq: int = attr.ib(default=100)
+ def __attrs_pre_init__(self):
+ super().__init__()
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Watches model weights with wandb."""
logger = get_wandb_logger(trainer)
@@ -41,6 +44,9 @@ class UploadCodeAsArtifact(Callback):
project_dir: Path = attr.ib(converter=Path)
+ def __attrs_pre_init__(self):
+ super().__init__()
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads project code as an artifact."""
logger = get_wandb_logger(trainer)
@@ -59,6 +65,9 @@ class UploadCheckpointAsArtifact(Callback):
ckpt_dir: Path = attr.ib(converter=Path)
upload_best_only: bool = attr.ib()
+ def __attrs_pre_init__(self):
+ super().__init__()
+
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Uploads model checkpoint to W&B."""
logger = get_wandb_logger(trainer)