summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/download_utils.py
blob: e3dc68c78676ad9e64882901380bb3b40323d7a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Util functions for downloading datasets."""
import hashlib
from pathlib import Path
from typing import Dict, List, Optional
from urllib.request import urlretrieve

from loguru import logger
from tqdm import tqdm


def _compute_sha256(filename: Path) -> str:
    """Returns the SHA256 checksum of a file."""
    with filename.open(mode="rb") as f:
        return hashlib.sha256(f.read()).hexdigest()


class TqdmUpTo(tqdm):
    """Progress bar that can be driven by a download report hook.

    Adapted from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py

    """

    def update_to(
        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
    ) -> None:
        """Moves the progress bar to the absolute amount transferred so far.

        Args:
            blocks (int): Number of blocks transferred so far. Defaults to 1.
            block_size (int): Size of each block, in tqdm units. Defaults to 1.
            total_size (Optional[int]): Total size in tqdm units. When given, it
                replaces the bar's current total. Defaults to None.
        """
        if total_size is not None:
            self.total = total_size  # pylint: disable=attribute-defined-outside-init
        # update() expects an increment, so subtract what is already counted.
        transferred = blocks * block_size
        increment = transferred - self.n
        self.update(increment)


def _download_url(url: str, filename: str) -> None:
    """Retrieves url into filename while displaying a progress bar."""
    # Report progress in bytes, scaled with a 1024 divisor.
    bar_options = {"unit": "B", "unit_scale": True, "unit_divisor": 1024, "miniters": 1}
    with TqdmUpTo(**bar_options) as progress_bar:
        urlretrieve(  # nosec
            url, filename=filename, reporthook=progress_bar.update_to, data=None
        )


def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
    """Downloads dataset using a metadata file.

    The download directory is created if needed. If the target file already
    exists, nothing is downloaded. Otherwise the file is fetched and its SHA-256
    is compared against the metadata. A file that could not be fully downloaded
    or verified is removed, so that a later call does not mistake it for a valid
    dataset.

    Args:
        metadata (Dict): A metadata file of the dataset. The keys "filename",
            "url" and "sha256" are read.
        dl_dir (Path): Download directory for the dataset.

    Returns:
        Optional[Path]: Returns filename if dataset is downloaded, None if it already
            exists.

    Raises:
        ValueError: If the SHA-256 value is not the same between the dataset and
            the metadata file.

    """
    dl_dir.mkdir(parents=True, exist_ok=True)
    filename = dl_dir / metadata["filename"]
    if filename.exists():
        return None
    logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
    verified = False
    try:
        _download_url(metadata["url"], filename)
        logger.info("Computing the SHA-256...")
        sha256 = _compute_sha256(filename)
        if sha256 != metadata["sha256"]:
            raise ValueError(
                "Downloaded data file SHA-256 does not match that listed in metadata document."
            )
        verified = True
    finally:
        # Without this cleanup, the exists() check above would make the next
        # call silently accept a partial or corrupt file.
        if not verified and filename.exists():
            filename.unlink()
    return filename