1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
"""Fetches model artifacts from wandb."""
from datetime import datetime
from pathlib import Path
import shutil
import sys
from typing import Optional, Union
import click
from loguru import logger as log
import wandb
from wandb.apis.public import Run
# Absolute path of the directory that contains this script; the other two
# locations below are derived from it.
TRAINING_DIR = Path(__file__).parent.resolve()
# Export destination: the "artifacts" folder of the sibling "text_recognizer"
# directory (one level above TRAINING_DIR).
ARTIFACTS_DIR = TRAINING_DIR.parent.joinpath("text_recognizer", "artifacts")
# Root of the local run logs that are searched for a run's output directory.
RUNS_DIR = TRAINING_DIR.joinpath("logs", "runs")
def _get_run_dir(run: Run) -> Optional[Path]:
    """Locate the local log directory that belongs to a wandb run.

    The run's ``created_at`` timestamp is parsed and used to build the lookup
    pattern ``RUNS_DIR/<date>/<hour>-*``.

    Args:
        run: The wandb run whose local output directory is wanted.

    Returns:
        The first path produced by the glob, or ``None`` when nothing matches.
    """
    # ``astimezone()`` without an argument attaches the system local timezone;
    # a naive timestamp is treated as if it already were local time.
    # NOTE(review): presumably ``created_at`` is a naive UTC string, which
    # would explain the manual offset shift below - confirm against wandb.
    started = datetime.fromisoformat(run.created_at).astimezone()

    # NOTE(review): the UTC offset is applied to the hour only; the date is
    # taken from the unshifted timestamp, so the two can disagree around
    # midnight. The hour is also not zero-padded - verify that this matches
    # how the run directories are actually named.
    shifted_hour = (started + started.utcoffset()).hour

    day_dir = RUNS_DIR / f"{started.date()}"
    return next(day_dir.glob(f"{shifted_hour}-*"), None)
def _get_best_weights(run_dir: Path) -> Optional[Path]:
    """Find a model checkpoint somewhere below a run directory.

    Args:
        run_dir: Directory that is searched recursively.

    Returns:
        The first file whose name matches ``epoch=*.ckpt``, or ``None`` when
        there is no such file. If several checkpoints match, whichever one the
        glob yields first is returned; no metric is compared here.
    """
    return next(run_dir.glob("**/epoch=*.ckpt"), None)
def _copy_config(run_dir: Path, dst_dir: Path) -> None:
    """Copy the run's ``config.yaml`` into the artifact directory.

    Args:
        run_dir: Run directory that contains ``config.yaml``.
        dst_dir: Existing directory that receives the copy under the same name.
    """
    log.info(f"Copying config to artifact folders ({dst_dir})")
    source = run_dir / "config.yaml"
    target = dst_dir / "config.yaml"
    shutil.copyfile(src=source, dst=target)
def _copy_checkpoint(checkpoint: Path, dst_dir: Path) -> None:
    """Copy a local checkpoint file into the artifact directory.

    Args:
        checkpoint: Path of the checkpoint file to copy.
        dst_dir: Existing directory that receives the copy as ``model.pt``.
    """
    log.info(f"Copying best run ({checkpoint}) to artifact folders ({dst_dir})")
    target = dst_dir / "model.pt"
    shutil.copyfile(src=checkpoint, dst=target)
def save_model(run: Run, tag: str) -> None:
    """Save model to artifacts.

    Looks up the run's local log directory and a checkpoint inside it, then
    copies the config and the checkpoint to
    ``ARTIFACTS_DIR/<tag>_text_recognizer``.

    If either the run directory or a checkpoint cannot be found locally, an
    error is logged and the function returns without copying anything or
    creating the destination folder.

    Args:
        run: The wandb run whose artifacts should be exported.
        tag: Prefix of the destination folder name.
    """
    run_dir = _get_run_dir(run)
    if run_dir is None:
        log.error("Could not find experiment locally!")
        # Stop here: continuing would pass None on to the checkpoint lookup.
        return

    best_weights = _get_best_weights(run_dir)
    if best_weights is None:
        log.error("Could not find checkpoint locally!")
        # Stop here: there is nothing to copy.
        return

    # Create the destination only once there is something to put in it, so a
    # failed lookup does not leave an empty artifact folder behind.
    dst_dir = ARTIFACTS_DIR / f"{tag}_text_recognizer"
    dst_dir.mkdir(parents=True, exist_ok=True)

    _copy_config(run_dir, dst_dir)
    _copy_checkpoint(best_weights, dst_dir)
    log.info("Successfully moved model and config to artifacts directory")

    # TODO: be able to download from w&b
def find_best_run(entity: str, project: str, tag: str, metric: str, mode: str) -> Run:
    """Find the best model on wandb.

    Queries all runs of ``entity/project`` that carry ``tag`` and returns the
    one whose summary value for ``metric`` is smallest (``mode == "min"``) or
    largest (any other ``mode``). Runs that did not log the metric are never
    preferred over runs that did.

    Args:
        entity: wandb entity that owns the project.
        project: wandb project name.
        tag: Only runs carrying this tag are considered.
        metric: Key in the run summary used for the comparison.
        mode: ``"min"`` to pick the lowest value, otherwise the highest.

    Returns:
        The best run according to ``metric`` and ``mode``.

    Raises:
        IndexError: If no run carries ``tag``.
        KeyError: If none of the tagged runs has ``metric`` in its summary.
    """
    minimize = mode == "min"
    # A run without the metric gets the worst possible value for the chosen
    # direction, so it can only win when every run lacks the metric.
    missing_value = float("inf") if minimize else float("-inf")
    # min()/max() return the first of several equal candidates, which is the
    # same run a stable sort would have placed at index 0.
    pick = min if minimize else max

    api = wandb.Api()
    runs = list(api.runs(f"{entity}/{project}", filters={"tags": {"$in": [tag]}}))
    if not runs:
        raise IndexError(f"No runs tagged '{tag}' found in {entity}/{project}")

    best_run = pick(runs, key=lambda run: run.summary.get(metric, missing_value))

    absent = object()
    best_value = best_run.summary.get(metric, absent)
    if best_value is absent:
        raise KeyError(
            f"None of the {len(runs)} runs tagged '{tag}' logged the metric '{metric}'"
        )

    log.info(
        f"Best run is ({best_run.name}, {best_run.id}) picked from {len(runs)} runs with the following metric"
    )
    log.info(f"{metric}: {best_value}")
    return best_run
@click.command()
@click.option(
    "--entity",
    type=str,
    default="aktersnurra",
    help="Name of the author",
)
@click.option(
    "--project",
    type=str,
    default="text-recognizer",
    help="The wandb project name",
)
@click.option(
    "--tag",
    type=click.Choice(["paragraphs", "lines"]),
    default="paragraphs",
    help="Tag to filter by",
)
@click.option(
    "--metric",
    type=str,
    default="val_loss",
    help="Which metric to filter on",
)
@click.option(
    "--mode",
    type=click.Choice(["min", "max"]),
    default="min",
    help="Min or max value of metric",
)
def main(entity: str, project: str, tag: str, metric: str, mode: str) -> None:
    # Command-line entry point: look up the best tagged run on wandb, then copy
    # its locally stored config and checkpoint into the artifacts directory.
    # Documented with a comment rather than a docstring because click uses a
    # command's docstring as its --help text.
    save_model(find_best_run(entity, project, tag, metric, mode), tag)


# Run the CLI only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|