diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-26 00:02:03 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-26 00:02:03 +0200 |
commit | adf411873c553a587495149b645a49f2bfb4e131 (patch) | |
tree | 7b45c9de160bf17a1d2c36fa00d6b34c6174807f | |
parent | a6c42e6f7cb70c1a06e46716f141c8f793a64e04 (diff) |
Add artifacts script
-rw-r--r-- | training/artifacts.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/training/artifacts.py b/training/artifacts.py new file mode 100644 index 0000000..9f970f8 --- /dev/null +++ b/training/artifacts.py @@ -0,0 +1,115 @@ +"""Fetches model artifacts from wandb.""" +from datetime import datetime +from pathlib import Path +import shutil +import sys +from typing import Optional, Union + +import click +from loguru import logger as log +import wandb +from wandb.apis.public import Run + +TRAINING_DIR = Path(__file__).parents[0].resolve() +ARTIFACTS_DIR = TRAINING_DIR.parent / "text_recognizer" / "artifacts" +RUNS_DIR = TRAINING_DIR / "logs" / "runs" + + +def _get_run_dir(run: Run) -> Optional[Path]: + created_at = datetime.fromisoformat(run.created_at).astimezone() + date = created_at.date() + hour = (created_at + created_at.utcoffset()).hour + runs = list((RUNS_DIR / f"{date}").glob(f"{hour}-*")) + if not runs: + return None + return runs[0] + + +def _get_best_weights(run_dir: Path) -> Optional[Path]: + checkpoints = list(run_dir.glob("**/epoch=*.ckpt")) + if not checkpoints: + return None + return checkpoints[0] + + +def _copy_config(run_dir: Path, dst_dir: Path) -> None: + log.info(f"Copying config to artifact folders ({dst_dir})") + shutil.copyfile(src=run_dir / "config.yaml", dst=dst_dir / "config.yaml") + + +def _copy_checkpoint(checkpoint: Path, dst_dir: Path) -> None: + """Copy model checkpoint from local directory.""" + log.info(f"Copying best run ({checkpoint}) to artifact folders ({dst_dir})") + shutil.copyfile(src=checkpoint, dst=dst_dir / "model.pt") + + +def save_model(run: Run, tag: str) -> None: + """Save model to artifacts.""" + dst_dir = ARTIFACTS_DIR / f"{tag}_text_recognizer" + dst_dir.mkdir(parents=True, exist_ok=True) + run_dir = _get_run_dir(run) + if not run_dir: + log.error("Could not find experiment locally!") + best_weights = _get_best_weights(run_dir) + if not best_weights: + log.error("Could not find checkpoint locally!") + _copy_config(run_dir, dst_dir) + _copy_checkpoint(best_weights, dst_dir) + log.info("Successfully moved model and config to artifacts directory") + # TODO: be able to download from w&b + + +def find_best_model(entity: str, project: str, tag: str, metric: str, mode: str) -> Run: + """Find the best model on wandb.""" + if mode == "min": + default_metric_value = sys.maxsize + sort_reverse = False + else: + default_metric_value = 0 + sort_reverse = True + api = wandb.Api() + runs = api.runs(f"{entity}/{project}", filters={"tags": {"$in": [tag]}}) + runs = sorted( + runs, + key=lambda run: run.summary.get(metric, default_metric_value), + reverse=sort_reverse, + ) + best_run = runs[0] + summary = best_run.summary + log.info( + f"Best run is ({best_run.name}, {best_run.id}) picked from {len(runs)} runs with the following metric" + ) + log.info( + f"{metric}: {summary[metric]}" + # , {metric.replace('val', 'test')}: {summary[metric.replace('val', 'test')]}" + ) + return best_run + + +@click.command() +@click.option("--entity", type=str, default="aktersnurra", help="Name of the author") +@click.option( + "--project", type=str, default="text-recognizer", help="The wandb project name" +) +@click.option( + "--tag", + type=click.Choice(["paragraphs", "lines"]), + default="paragraphs", + help="Tag to filter by", +) +@click.option( + "--metric", type=str, default="val_loss", help="Which metric to filter on" +) +@click.option( + "--mode", + type=click.Choice(["min", "max"]), + default="min", + help="Min or max value of metric", +) +def main(entity: str, project: str, tag: str, metric: str, mode: str) -> None: + best_run = find_best_model(entity, project, tag, metric, mode) + save_model(best_run, tag) + + +if __name__ == "__main__": + main() |