diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
commit | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch) | |
tree | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/datasets/util.py | |
parent | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff) |
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r-- | src/text_recognizer/datasets/util.py | 99 |
1 file changed, 48 insertions, 51 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 76bd85f..dd16bed 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -1,10 +1,17 @@ """Util functions for datasets.""" +import hashlib import importlib -from typing import Callable, Dict, List, Type +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union +from urllib.request import urlopen, urlretrieve +import cv2 +from loguru import logger import numpy as np from PIL import Image from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm class Transpose: @@ -15,58 +22,48 @@ class Transpose: return np.array(image).swapaxes(0, 1) -def fetch_data_loaders( - splits: List[str], - dataset: str, - dataset_args: Dict, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, -) -> Dict[str, DataLoader]: - """Fetches DataLoaders for given split(s) as a dictionary. - - Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then - calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value. - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - dataset (str): The name of the dataset. - dataset_args (Dict): The dataset arguments. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - - Returns: - Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value. 
def compute_sha256(filename: Union[Path, str]) -> str:
    """Returns the SHA256 checksum of a file.

    The file is read in fixed-size chunks so that arbitrarily large dataset
    archives can be hashed without loading them into memory at once.

    Args:
        filename (Union[Path, str]): Path of the file to checksum.

    Returns:
        str: Hex-encoded SHA-256 digest of the file contents.

    """
    sha256 = hashlib.sha256()
    with open(filename, "rb") as f:
        # Stream in 64 KiB chunks instead of f.read(): avoids holding the
        # entire file in memory while producing an identical digest.
        for chunk in iter(lambda: f.read(65536), b""):
            sha256.update(chunk)
    return sha256.hexdigest()


class TqdmUpTo(tqdm):
    """TQDM progress bar when downloading files.

    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py

    """

    def update_to(
        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
    ) -> None:
        """Updates the progress bar.

        Args:
            blocks (int): Number of blocks transferred so far. Defaults to 1.
            block_size (int): Size of each block, in tqdm units. Defaults to 1.
            total_size (Optional[int]): Total size in tqdm units. Defaults to None.

        """
        if total_size is not None:
            self.total = total_size  # pylint: disable=attribute-defined-outside-init
        # urlretrieve's reporthook passes a cumulative block count; convert it
        # to a delta relative to what the bar has already recorded (self.n).
        self.update(blocks * block_size - self.n)


def download_url(url: str, filename: str) -> None:
    """Downloads a file from url to filename, with a progress bar."""
    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # nosec


def _download_raw_dataset(metadata: Dict) -> None:
    """Downloads and checksum-verifies the raw dataset described by metadata.

    Args:
        metadata (Dict): Must contain the keys "url", "filename" and "sha256".

    Raises:
        ValueError: If the SHA-256 of the downloaded file does not match the
            checksum listed in the metadata document.

    """
    # NOTE(review): an already-present file is trusted without re-verifying
    # its checksum; only fresh downloads are checked.
    if os.path.exists(metadata["filename"]):
        return
    logger.info(f"Downloading raw dataset from {metadata['url']}...")
    download_url(metadata["url"], metadata["filename"])
    logger.info("Computing SHA-256...")
    sha256 = compute_sha256(metadata["filename"])
    if sha256 != metadata["sha256"]:
        # Remove the corrupt download; otherwise the existence guard above
        # would treat it as a valid cached file on the next invocation.
        os.remove(metadata["filename"])
        raise ValueError(
            "Downloaded data file SHA-256 does not match that listed in metadata document."
        )