From e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Tue, 8 Sep 2020 23:14:23 +0200
Subject: IAM datasets implemented.

---
 src/text_recognizer/datasets/util.py | 99 +++++++++++++++++------------------
 1 file changed, 48 insertions(+), 51 deletions(-)

(limited to 'src/text_recognizer/datasets/util.py')

diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 76bd85f..dd16bed 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,10 +1,17 @@
 """Util functions for datasets."""
+import hashlib
 import importlib
-from typing import Callable, Dict, List, Type
+import os
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+from urllib.request import urlopen, urlretrieve
 
+import cv2
+from loguru import logger
 import numpy as np
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
 
 
 class Transpose:
@@ -15,58 +22,48 @@ class Transpose:
         return np.array(image).swapaxes(0, 1)
 
 
-def fetch_data_loaders(
-    splits: List[str],
-    dataset: str,
-    dataset_args: Dict,
-    batch_size: int = 128,
-    shuffle: bool = False,
-    num_workers: int = 0,
-    cuda: bool = True,
-) -> Dict[str, DataLoader]:
-    """Fetches DataLoaders for given split(s) as a dictionary.
-
-    Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then
-    calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value.
-
-    Args:
-        splits (List[str]): One or both of the dataset splits "train" and "val".
-        dataset (str): The name of the dataset.
-        dataset_args (Dict): The dataset arguments.
-        batch_size (int): How many samples per batch to load. Defaults to 128.
-        shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
-        num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
-            loaded in the main process. Defaults to 0.
-        cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
-            them. Defaults to True.
-
-    Returns:
-        Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value.
+def compute_sha256(filename: Union[Path, str]) -> str:
+    """Returns the SHA256 checksum of a file."""
+    with open(filename, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
 
-    """
-
-    def check_dataset_args(args: Dict, split: str) -> Dict:
-        """Adds train flag to the dataset args."""
-        args["train"] = True if split == "train" else False
-        return args
-
-    # Import dataset module.
-    datasets_module = importlib.import_module("text_recognizer.datasets")
-    dataset_ = getattr(datasets_module, dataset)
 
-    data_loaders = {}
+class TqdmUpTo(tqdm):
+    """TQDM progress bar when downloading files.
 
-    for split in ["train", "val"]:
-        if split in splits:
+    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
 
-            data_loader = DataLoader(
-                dataset=dataset_(**check_dataset_args(dataset_args, split)),
-                batch_size=batch_size,
-                shuffle=shuffle,
-                num_workers=num_workers,
-                pin_memory=cuda,
-            )
-
-            data_loaders[split] = data_loader
+    """
 
-    return data_loaders
+    def update_to(
+        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+    ) -> None:
+        """Updates the progress bar.
+
+        Args:
+            blocks (int): Number of blocks transferred so far. Defaults to 1.
+            block_size (int): Size of each block, in tqdm units. Defaults to 1.
+            total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+        """
+        if total_size is not None:
+            self.total = total_size  # pylint: disable=attribute-defined-outside-init
+        self.update(blocks * block_size - self.n)
+
+
+def download_url(url: str, filename: str) -> None:
+    """Downloads a file from url to filename, with a progress bar."""
+    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # nosec
+
+
+def _download_raw_dataset(metadata: Dict) -> None:
+    if os.path.exists(metadata["filename"]):
+        return
+    logger.info(f"Downloading raw dataset from {metadata['url']}...")
+    download_url(metadata["url"], metadata["filename"])
+    logger.info("Computing SHA-256...")
+    sha256 = compute_sha256(metadata["filename"])
+    if sha256 != metadata["sha256"]:
+        raise ValueError(
+            "Downloaded data file SHA-256 does not match that listed in metadata document."
+        )
-- 
cgit v1.2.3-70-g09d2
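For reference, a minimal usage sketch of the helpers added in this commit (not part of the patch itself). The metadata dict keys ("url", "filename", "sha256") mirror what _download_raw_dataset reads; the URL, local path, and checksum values below are placeholders, and the import path assumes the package is importable as text_recognizer.

    """Usage sketch (assumed, not from the commit): drive the download helpers."""
    from text_recognizer.datasets.util import _download_raw_dataset, compute_sha256

    # Placeholder metadata; a real metadata document would carry the actual URL,
    # target filename, and expected SHA-256 of the raw IAM archive.
    metadata = {
        "url": "https://example.com/iam/forms.tgz",
        "filename": "data/raw/iam/forms.tgz",
        "sha256": "0" * 64,
    }

    # Skips the download if the file already exists; otherwise fetches it with a
    # TqdmUpTo progress bar and raises ValueError on a checksum mismatch.
    _download_raw_dataset(metadata)

    # The checksum helper can also be called on its own, e.g. to refresh the
    # metadata document after the raw archive changes.
    print(compute_sha256(metadata["filename"]))

Keeping the expected SHA-256 next to the URL in a metadata document lets the dataset classes verify a previously downloaded archive without fetching it again.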