summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/download_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/datasets/download_utils.py')
-rw-r--r--text_recognizer/datasets/download_utils.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py
new file mode 100644
index 0000000..7a2cab8
--- /dev/null
+++ b/text_recognizer/datasets/download_utils.py
@@ -0,0 +1,73 @@
+"""Util functions for downloading datasets."""
+import hashlib
+from pathlib import Path
+from typing import Dict, List, Optional
+from urllib.request import urlretrieve
+
+from loguru import logger
+from tqdm import tqdm
+
+
+def _compute_sha256(filename: Path) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with filename.open(mode="rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
+
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+ """
+
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def _download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
+ """Downloads dataset using a metadata file.
+
+ Args:
+ metadata (Dict): A metadata file of the dataset.
+ dl_dir (Path): Download directory for the dataset.
+
+ Returns:
+ Optional[Path]: Returns filename if dataset is downloaded, None if it already
+ exists.
+
+ Raises:
+ ValueError: If the SHA-256 value is not the same between the dataset and
+ the metadata file.
+
+ """
+ dl_dir.mkdir(parents=True, exist_ok=True)
+ filename = dl_dir / metadata["filename"]
+ if filename.exists():
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+ _download_url(metadata["url"], filename)
+ logger.info("Computing the SHA-256...")
+ sha256 = _compute_sha256(filename)
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
+ return filename