summaryrefslogtreecommitdiff
path: root/training/artifacts.py
blob: a059833736afb8a048c8f0dafb2b1501ecacb3e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Fetches model artifacts from wandb."""
from datetime import datetime, timezone
from pathlib import Path
import shutil
import sys
from typing import Optional, Union

import click
from loguru import logger as log
import wandb
from wandb.apis.public import Run
from wandb.apis.public import Run

# Directory containing this script; the other paths are derived from it.
TRAINING_DIR = Path(__file__).parent.resolve()
# Destination directory for exported model artifacts (next to the training dir).
ARTIFACTS_DIR = TRAINING_DIR.parent.joinpath("text_recognizer", "artifacts")
# Root of the local experiment logs that are searched for run directories.
RUNS_DIR = TRAINING_DIR.joinpath("logs", "runs")


def _get_run_dir(run: Run) -> Optional[Path]:
    """Find the local log directory of a wandb run.

    The run's ``created_at`` timestamp is converted to local wall-clock time
    and used to search ``RUNS_DIR / <date> / <hour>-*``.

    Args:
        run: wandb run whose ``created_at`` ISO-8601 string is used.

    Returns:
        The first matching directory (in sorted order), or ``None`` if no
        directory matches.
    """
    timestamp = run.created_at
    if timestamp.endswith("Z"):
        # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11.
        timestamp = timestamp[:-1] + "+00:00"
    created_at = datetime.fromisoformat(timestamp)
    if created_at.tzinfo is None:
        # NOTE(review): naive timestamps are treated as UTC, which is what the
        # previous offset arithmetic assumed — confirm against wandb.
        created_at = created_at.replace(tzinfo=timezone.utc)
    # Convert once so the date and the hour come from the same local instant;
    # previously the date was taken before shifting the hour (off by one day
    # around midnight) and aware timestamps were shifted twice.
    local = created_at.astimezone()
    day_dir = RUNS_DIR / f"{local.date()}"
    # Accept zero-padded hours (matching the zero-padded date directory) as
    # well as the unpadded form the previous glob used.
    patterns = {f"{local.hour:02d}-*", f"{local.hour}-*"}
    runs = sorted({path for pattern in patterns for path in day_dir.glob(pattern)})
    if not runs:
        return None
    return runs[0]


def _get_best_weights(run_dir: Path) -> Optional[Path]:
    """Return a ``epoch=*.ckpt`` checkpoint found anywhere under ``run_dir``.

    The search is recursive. When several checkpoints exist, the first one
    yielded by ``Path.glob`` is returned; that order is filesystem-dependent
    and is not based on any metric.

    Args:
        run_dir: Run directory to search.

    Returns:
        The checkpoint path, or ``None`` when no checkpoint exists.
    """
    matches = run_dir.glob("**/epoch=*.ckpt")
    return next(iter(matches), None)


def _copy_config(run_dir: Path, dst_dir: Path) -> None:
    """Copy ``config.yaml`` from a run directory into ``dst_dir``.

    Args:
        run_dir: Directory holding the run's ``config.yaml``.
        dst_dir: Directory to copy into (must already exist).
    """
    name = "config.yaml"
    log.info(f"Copying config to artifact folders ({dst_dir})")
    shutil.copyfile(run_dir / name, dst_dir / name)


def _copy_checkpoint(checkpoint: Path, dst_dir: Path) -> None:
    """Copy a local checkpoint file into ``dst_dir`` as ``model.pt``.

    Args:
        checkpoint: Checkpoint file to copy.
        dst_dir: Directory to copy into (must already exist).
    """
    target = dst_dir / "model.pt"
    log.info(f"Copying best run ({checkpoint}) to artifact folders ({dst_dir})")
    shutil.copyfile(checkpoint, target)


def save_model(run: Run, tag: str) -> None:
    """Copy a run's config and checkpoint into the artifacts directory.

    Files are looked up in the run's local log directory and written to
    ``ARTIFACTS_DIR / f"{tag}_text_recognizer"`` as ``config.yaml`` and
    ``model.pt``. If the run directory or a checkpoint cannot be found, an
    error is logged and nothing is created or copied.

    Args:
        run: wandb run to export.
        tag: Tag used to name the destination folder.
    """
    run_dir = _get_run_dir(run)
    if not run_dir:
        # Stop here: continuing would pass None on and crash.
        log.error("Could not find experiment locally!")
        return
    best_weights = _get_best_weights(run_dir)
    if not best_weights:
        log.error("Could not find checkpoint locally!")
        return
    # Only create the destination once there is something to put in it.
    dst_dir = ARTIFACTS_DIR / f"{tag}_text_recognizer"
    dst_dir.mkdir(parents=True, exist_ok=True)
    _copy_config(run_dir, dst_dir)
    _copy_checkpoint(best_weights, dst_dir)
    log.info("Successfully moved model and config to artifacts directory")
    # TODO: be able to download from w&b


def find_best_run(entity: str, project: str, tag: str, metric: str, mode: str) -> Run:
    """Find the best run on wandb for a tag, ranked by a summary metric.

    Args:
        entity: wandb entity owning the project.
        project: wandb project name.
        tag: Only runs carrying this tag are considered.
        metric: Key in each run's summary to rank by.
        mode: ``"min"`` picks the lowest value; any other value picks the highest.

    Returns:
        The best-ranked run. Runs whose summary lacks ``metric`` (or holds
        ``None`` for it) are ranked last.

    Raises:
        ValueError: If no run matches the tag.
    """
    minimize = mode == "min"
    # Infinity (rather than sys.maxsize / 0) so runs without the metric always
    # rank last, even when real values are negative or very large.
    missing_value = float("inf") if minimize else float("-inf")

    def _score(run):
        value = run.summary.get(metric)
        return missing_value if value is None else value

    api = wandb.Api()
    runs = list(api.runs(f"{entity}/{project}", filters={"tags": {"$in": [tag]}}))
    if not runs:
        raise ValueError(f"No runs tagged {tag!r} found in {entity}/{project}")
    # min/max return the first of equally-scored runs, matching the tie-breaking
    # of the previous stable sorted(...)[0].
    best_run = (min if minimize else max)(runs, key=_score)
    log.info(
        f"Best run is ({best_run.name}, {best_run.id}) picked from {len(runs)} runs with the following metric"
    )
    log.info(f"{metric}: {best_run.summary.get(metric)}")
    return best_run


@click.command()
@click.option("--entity", type=str, default="aktersnurra", help="Name of the author")
@click.option(
    "--project",
    type=str,
    default="text-recognizer",
    help="The wandb project name",
)
@click.option(
    "--tag",
    type=click.Choice(["paragraphs", "lines"]),
    default="paragraphs",
    help="Tag to filter by",
)
@click.option(
    "--metric",
    type=str,
    default="val_loss",
    help="Which metric to filter on",
)
@click.option(
    "--mode",
    type=click.Choice(["min", "max"]),
    default="min",
    help="Min or max value of metric",
)
def main(entity: str, project: str, tag: str, metric: str, mode: str) -> None:
    # No docstring on purpose: click would show it as the command's --help text.
    # Rank the tagged wandb runs, then export the winner's local files.
    save_model(find_best_run(entity, project, tag, metric, mode), tag)


if __name__ == "__main__":
    # Run the click command when this file is executed as a script.
    main()