1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
"""Util functions for downloading datasets."""
import hashlib
from pathlib import Path
from typing import Dict, List, Optional
from urllib.request import urlretrieve
from loguru import logger
from tqdm import tqdm
def _compute_sha256(filename: Path) -> str:
"""Returns the SHA256 checksum of a file."""
with filename.open(mode="rb") as f:
return hashlib.sha256(f.read()).hexdigest()
class TqdmUpTo(tqdm):
    """Progress bar that can be driven by absolute transfer counts.

    Adapted from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(
        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
    ) -> None:
        """Moves the bar forward to ``blocks * block_size``.

        The signature matches the ``reporthook`` callback that ``urlretrieve``
        invokes while a download is running.

        Args:
            blocks (int): Number of blocks transferred so far. Defaults to 1.
            block_size (int): Size of each block, in tqdm units. Defaults to 1.
            total_size (Optional[int]): Total size in tqdm units. Defaults to None,
                in which case the bar's current total is left unchanged.
        """
        if total_size is not None:
            self.total = total_size  # pylint: disable=attribute-defined-outside-init
        transferred = blocks * block_size
        # update() expects an increment, so subtract what the bar already counted.
        increment = transferred - self.n
        self.update(increment)
def _download_url(url: str, filename: str) -> None:
    """Fetches ``url`` into ``filename`` while displaying a progress bar.

    Args:
        url (str): Address of the file to download.
        filename (str): Local path the downloaded content is written to.
    """
    bar_options = {
        "unit": "B",
        "unit_scale": True,
        "unit_divisor": 1024,
        "miniters": 1,
    }
    with TqdmUpTo(**bar_options) as progress_bar:
        # urlretrieve reports progress through the hook after each block.
        urlretrieve(  # nosec
            url, filename, reporthook=progress_bar.update_to, data=None
        )
def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
    """Downloads dataset using a metadata file.

    The data is first written to a temporary ``<filename>.part`` file next to the
    target and only moved to its final name once its SHA-256 matches the metadata.
    This way an interrupted or corrupted download never ends up at the final path,
    where a later call would mistake it for an already-downloaded dataset.

    Args:
        metadata (Dict): A metadata file of the dataset. The keys ``filename``,
            ``url`` and ``sha256`` are read.
        dl_dir (Path): Download directory for the dataset. Created if it does not
            exist.

    Returns:
        Optional[Path]: Returns filename if dataset is downloaded, None if it already
            exists.

    Raises:
        ValueError: If the SHA-256 value is not the same between the dataset and
            the metadata file. The downloaded file is removed in that case.
    """
    dl_dir.mkdir(parents=True, exist_ok=True)
    filename = dl_dir / metadata["filename"]
    if filename.exists():
        return None

    # Unverified data lives under a separate name until the checksum is confirmed.
    part_file = filename.with_name(filename.name + ".part")
    logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
    try:
        _download_url(metadata["url"], str(part_file))
        logger.info("Computing the SHA-256...")
        sha256 = _compute_sha256(part_file)
        if sha256 != metadata["sha256"]:
            raise ValueError(
                "Downloaded data file SHA-256 does not match that listed in metadata document."
            )
        part_file.replace(filename)
    finally:
        # After a successful replace() the partial file no longer exists; on any
        # failure (download error, interrupt, checksum mismatch) remove the leftover.
        if part_file.exists():
            part_file.unlink()
    return filename
|