summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-26 00:02:03 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-26 00:02:03 +0200
commitadf411873c553a587495149b645a49f2bfb4e131 (patch)
tree7b45c9de160bf17a1d2c36fa00d6b34c6174807f /training
parenta6c42e6f7cb70c1a06e46716f141c8f793a64e04 (diff)
Add artifacts script
Diffstat (limited to 'training')
-rw-r--r--training/artifacts.py115
1 files changed, 115 insertions, 0 deletions
diff --git a/training/artifacts.py b/training/artifacts.py
new file mode 100644
index 0000000..9f970f8
--- /dev/null
+++ b/training/artifacts.py
@@ -0,0 +1,115 @@
+"""Fetches model artifacts from wandb."""
+from datetime import datetime
+from pathlib import Path
+import shutil
+import sys
+from typing import Optional, Union
+
+import click
+from loguru import logger as log
+import wandb
+from wandb.apis.public import Run
+
+TRAINING_DIR = Path(__file__).parents[0].resolve()
+ARTIFACTS_DIR = TRAINING_DIR.parent / "text_recognizer" / "artifacts"
+RUNS_DIR = TRAINING_DIR / "logs" / "runs"
+
+
+def _get_run_dir(run: Run) -> Optional[Path]:
+ created_at = datetime.fromisoformat(run.created_at).astimezone()
+ date = created_at.date()
+ hour = (created_at + created_at.utcoffset()).hour
+ runs = list((RUNS_DIR / f"{date}").glob(f"{hour}-*"))
+ if not runs:
+ return None
+ return runs[0]
+
+
+def _get_best_weights(run_dir: Path) -> Optional[Path]:
+ checkpoints = list(run_dir.glob("**/epoch=*.ckpt"))
+ if not checkpoints:
+ return None
+ return checkpoints[0]
+
+
+def _copy_config(run_dir: Path, dst_dir: Path) -> None:
+ log.info(f"Copying config to artifact folders ({dst_dir})")
+ shutil.copyfile(src=run_dir / "config.yaml", dst=dst_dir / "config.yaml")
+
+
+def _copy_checkpoint(checkpoint: Path, dst_dir: Path) -> None:
+ """Copy model checkpoint from local directory."""
+ log.info(f"Copying best run ({checkpoint}) to artifact folders ({dst_dir})")
+ shutil.copyfile(src=checkpoint, dst=dst_dir / "model.pt")
+
+
+def save_model(run: Run, tag: str) -> None:
+ """Save model to artifacts."""
+ dst_dir = ARTIFACTS_DIR / f"{tag}_text_recognizer"
+ dst_dir.mkdir(parents=True, exist_ok=True)
+ run_dir = _get_run_dir(run)
+ if not run_dir:
+ log.error("Could not find experiment locally!")
+ best_weights = _get_best_weights(run_dir)
+ if not best_weights:
+ log.error("Could not find checkpoint locally!")
+ _copy_config(run_dir, dst_dir)
+ _copy_checkpoint(best_weights, dst_dir)
+ log.info("Successfully moved model and config to artifacts directory")
+ # TODO: be able to download from w&b
+
+
+def find_best_model(entity: str, project: str, tag: str, metric: str, mode: str) -> Run:
+ """Find the best model on wandb."""
+ if mode == "min":
+ default_metric_value = sys.maxsize
+ sort_reverse = False
+ else:
+ default_metric_value = 0
+ sort_reverse = True
+ api = wandb.Api()
+ runs = api.runs(f"{entity}/{project}", filters={"tags": {"$in": [tag]}})
+ runs = sorted(
+ runs,
+ key=lambda run: run.summary.get(metric, default_metric_value),
+ reverse=sort_reverse,
+ )
+ best_run = runs[0]
+ summary = best_run.summary
+ log.info(
+ f"Best run is ({best_run.name}, {best_run.id}) picked from {len(runs)} runs with the following metric"
+ )
+ log.info(
+ f"{metric}: {summary[metric]}"
+ # , {metric.replace('val', 'test')}: {summary[metric.replace('val', 'test')]}"
+ )
+ return best_run
+
+
+@click.command()
+@click.option("--entity", type=str, default="aktersnurra", help="Name of the author")
+@click.option(
+ "--project", type=str, default="text-recognizer", help="The wandb project name"
+)
+@click.option(
+ "--tag",
+ type=click.Choice(["paragraphs", "lines"]),
+ default="paragraphs",
+ help="Tag to filter by",
+)
+@click.option(
+ "--metric", type=str, default="val_loss", help="Which metric to filter on"
+)
+@click.option(
+ "--mode",
+ type=click.Choice(["min", "max"]),
+ default="min",
+ help="Min or max value of metric",
+)
+def main(entity: str, project: str, tag: str, metric: str, mode: str) -> None:
+ best_run = find_best_model(entity, project, tag, metric, mode)
+ save_model(best_run, tag)
+
+
+if __name__ == "__main__":
+ main()