From 72befa6d8cc4c7ecf698512a97424641ee81725a Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 8 Jul 2021 22:26:15 +0200
Subject: Move callbacks to training folder, refactor

---
 training/callbacks/__init__.py        |   1 +
 training/callbacks/wandb_callbacks.py | 211 ++++++++++++++++++++++++++++++++++
 training/conf/callbacks/wandb.yaml    |   8 +-
 training/run.py                       |   6 +-
 training/utils.py                     |   2 +-
 5 files changed, 220 insertions(+), 8 deletions(-)
 create mode 100644 training/callbacks/__init__.py
 create mode 100644 training/callbacks/wandb_callbacks.py

(limited to 'training')

diff --git a/training/callbacks/__init__.py b/training/callbacks/__init__.py
new file mode 100644
index 0000000..82d8ce3
--- /dev/null
+++ b/training/callbacks/__init__.py
@@ -0,0 +1 @@
+"""Module for PyTorch Lightning callbacks."""
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
new file mode 100644
index 0000000..d9d81f6
--- /dev/null
+++ b/training/callbacks/wandb_callbacks.py
@@ -0,0 +1,211 @@
+"""Weights and Biases callbacks."""
+from pathlib import Path
+from typing import List
+
+import attr
+import wandb
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+
+
+def get_wandb_logger(trainer: Trainer) -> WandbLogger:
+    """Safely get W&B logger from Trainer."""
+
+    if isinstance(trainer.logger, WandbLogger):
+        return trainer.logger
+
+    if isinstance(trainer.logger, LoggerCollection):
+        for logger in trainer.logger:
+            if isinstance(logger, WandbLogger):
+                return logger
+
+    raise Exception("Weight and Biases logger not found for some reason...")
+
+
+@attr.s
+class WatchModel(Callback):
+    """Make W&B watch the model at the beginning of the run."""
+
+    log: str = attr.ib(default="gradients")
+    log_freq: int = attr.ib(default=100)
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Watches model weights with wandb."""
+        logger = get_wandb_logger(trainer)
+        logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+
+
+@attr.s
+class UploadCodeAsArtifact(Callback):
+    """Upload all *.py files to W&B as an artifact, at the beginning of the run."""
+
+    project_dir: Path = attr.ib(converter=Path)
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Uploads project code as an artifact."""
+        logger = get_wandb_logger(trainer)
+        experiment = logger.experiment
+        artifact = wandb.Artifact("project-source", type="code")
+        for filepath in self.project_dir.glob("**/*.py"):
+            artifact.add_file(filepath)
+
+        experiment.use_artifact(artifact)
+
+
+@attr.s
+class UploadCheckpointAsArtifact(Callback):
+    """Upload checkpoint to wandb as an artifact, at the end of a run."""
+
+    ckpt_dir: Path = attr.ib(converter=Path)
+    upload_best_only: bool = attr.ib()
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Uploads model checkpoint to W&B."""
+        logger = get_wandb_logger(trainer)
+        experiment = logger.experiment
+        ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
+
+        if self.upload_best_only:
+            ckpts.add_file(trainer.checkpoint_callback.best_model_path)
+        else:
+            for ckpt in (self.ckpt_dir).glob("**/*.ckpt"):
+                ckpts.add_file(ckpt)
+
+        experiment.use_artifact(ckpts)
+
+
+@attr.s
+class LogTextPredictions(Callback):
+    """Logs a validation batch with image to text transcription."""
+
+    num_samples: int = attr.ib(default=8)
+    ready: bool = attr.ib(default=True)
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    def _log_predictions(
+        stage: str, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Logs the predicted text contained in the images."""
+        if not self.ready:
+            return None
+
+        logger = get_wandb_logger(trainer)
+        experiment = logger.experiment
+
+        # Get a validation batch from the validation dataloader.
+        samples = next(iter(trainer.datamodule.val_dataloader()))
+        imgs, labels = samples
+
+        imgs = imgs.to(device=pl_module.device)
+        logits = pl_module(imgs)
+
+        mapping = pl_module.mapping
+        experiment.log(
+            {
+                f"OCR/{experiment.name}/{stage}": [
+                    wandb.Image(
+                        img,
+                        caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
+                    )
+                    for img, pred, label in zip(
+                        imgs[: self.num_samples],
+                        logits[: self.num_samples],
+                        labels[: self.num_samples],
+                    )
+                ]
+            }
+        )
+
+    def on_sanity_check_start(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Sets ready attribute."""
+        self.ready = False
+
+    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Start executing this callback only after all validation sanity checks end."""
+        self.ready = True
+
+    def on_validation_epoch_end(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Logs predictions on validation epoch end."""
+        self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Logs predictions on train epoch end."""
+        self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
+
+
+@attr.s
+class LogReconstuctedImages(Callback):
+    """Log reconstructions of images."""
+
+    num_samples: int = attr.ib(default=8)
+    ready: bool = attr.ib(default=True)
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    def _log_reconstruction(
+        self, stage: str, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Logs the reconstructions."""
+        if not self.ready:
+            return None
+
+        logger = get_wandb_logger(trainer)
+        experiment = logger.experiment
+
+        # Get a validation batch from the validation dataloader.
+        samples = next(iter(trainer.datamodule.val_dataloader()))
+        imgs, _ = samples
+
+        imgs = imgs.to(device=pl_module.device)
+        reconstructions = pl_module(imgs)
+
+        experiment.log(
+            {
+                f"Reconstructions/{experiment.name}/{stage}": [
+                    [
+                        wandb.Image(img),
+                        wandb.Image(rec),
+                    ]
+                    for img, rec in zip(
+                        imgs[: self.num_samples],
+                        reconstructions[: self.num_samples],
+                    )
+                ]
+            }
+        )
+
+    def on_sanity_check_start(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Sets ready attribute."""
+        self.ready = False
+
+    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Start executing this callback only after all validation sanity checks end."""
+        self.ready = True
+
+    def on_validation_epoch_end(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Logs predictions on validation epoch end."""
+        self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Logs predictions on train epoch end."""
+        self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
index 2d56bfa..6eedb71 100644
--- a/training/conf/callbacks/wandb.yaml
+++ b/training/conf/callbacks/wandb.yaml
@@ -2,19 +2,19 @@ defaults:
   - default.yaml
 
 watch_model:
-  _target_: text_recognizer.callbacks.wandb_callbacks.WatchModel
+  _target_: callbacks.wandb_callbacks.WatchModel
   log: "all"
   log_freq: 100
 
 upload_code_as_artifact:
-  _target_: text_recognizer.callbacks.wandb_callbacks.UploadCodeAsArtifact
+  _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
   project_dir: ${work_dir}/text_recognizer
 
 upload_ckpts_as_artifact:
-  _target_: text_recognizer.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+  _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
   ckpt_dir: "checkpoints/"
   upload_best_only: True
 
 log_text_predictions:
-  _target_: text_recognizer.callbacks.wandb_callbacks.LogTextPredictions
+  _target_: callbacks.wandb_callbacks.LogTextPredictions
   num_samples: 8
diff --git a/training/run.py b/training/run.py
index 31da666..695a298 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,14 +2,14 @@
 from typing import List, Optional, Type
 
 import hydra
-import loguru.logger as log
+from loguru import logger as log
 from omegaconf import DictConfig
 from pytorch_lightning import (
     Callback,
     LightningDataModule,
     LightningModule,
-    Trainer,
     seed_everything,
+    Trainer,
 )
 from pytorch_lightning.loggers import LightningLoggerBase
 from torch import nn
@@ -67,7 +67,7 @@ def run(config: DictConfig) -> Optional[float]:
         log.info("Training network...")
         trainer.fit(model, datamodule=datamodule)
 
-    if config.test:
+    if config.test:lua/cfg/themes/dark.lua
         log.info("Testing network...")
         trainer.test(model, datamodule=datamodule)
 
diff --git a/training/utils.py b/training/utils.py
index 140d97e..88b72b7 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -3,8 +3,8 @@ from typing import Any, List, Type
 import warnings
 
 import hydra
+from loguru import logger as log
 from omegaconf import DictConfig, OmegaConf
-import loguru.logger as log
 from pytorch_lightning import (
     Callback,
     LightningModule,
-- 
cgit v1.2.3-70-g09d2