Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r--  src/text_recognizer/datasets/util.py  99
1 file changed, 48 insertions(+), 51 deletions(-)
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 76bd85f..dd16bed 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,10 +1,17 @@
"""Util functions for datasets."""
+import hashlib
import importlib
-from typing import Callable, Dict, List, Type
+import os
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+from urllib.request import urlopen, urlretrieve
+import cv2
+from loguru import logger
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
class Transpose:
@@ -15,58 +22,48 @@ class Transpose:
        return np.array(image).swapaxes(0, 1)
-def fetch_data_loaders(
-    splits: List[str],
-    dataset: str,
-    dataset_args: Dict,
-    batch_size: int = 128,
-    shuffle: bool = False,
-    num_workers: int = 0,
-    cuda: bool = True,
-) -> Dict[str, DataLoader]:
-    """Fetches DataLoaders for the given split(s) as a dictionary.
-
-    Loads the given dataset class with the dataset arguments for each requested split, wraps it in a
-    DataLoader, and stores it in a dictionary with the split as key and the DataLoader as value.
-
-    Args:
-        splits (List[str]): One or both of the dataset splits "train" and "val".
-        dataset (str): The name of the dataset.
-        dataset_args (Dict): The dataset arguments.
-        batch_size (int): How many samples per batch to load. Defaults to 128.
-        shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
-        num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
-            loaded in the main process. Defaults to 0.
-        cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
-            them. Defaults to True.
-
-    Returns:
-        Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value.
+def compute_sha256(filename: Union[Path, str]) -> str:
+    """Returns the SHA256 checksum of a file."""
+    with open(filename, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
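
A minimal sketch of how the new checksum helper could be called; the path below is a
placeholder, not a file from this repository:

    from text_recognizer.datasets.util import compute_sha256

    # Placeholder path; compare the digest against the value recorded in the
    # dataset's metadata document.
    digest = compute_sha256("data/raw/dataset.zip")
    print(digest)
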
- """
-
- def check_dataset_args(args: Dict, split: str) -> Dict:
- """Adds train flag to the dataset args."""
- args["train"] = True if split == "train" else False
- return args
-
- # Import dataset module.
- datasets_module = importlib.import_module("text_recognizer.datasets")
- dataset_ = getattr(datasets_module, dataset)
- data_loaders = {}
+class TqdmUpTo(tqdm):
+    """TQDM progress bar when downloading files.
-    for split in ["train", "val"]:
-        if split in splits:
+    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
-            data_loader = DataLoader(
-                dataset=dataset_(**check_dataset_args(dataset_args, split)),
-                batch_size=batch_size,
-                shuffle=shuffle,
-                num_workers=num_workers,
-                pin_memory=cuda,
-            )
-
-            data_loaders[split] = data_loader
+    """
-    return data_loaders
+    def update_to(
+        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+    ) -> None:
+        """Updates the progress bar.
+
+        Args:
+            blocks (int): Number of blocks transferred so far. Defaults to 1.
+            block_size (int): Size of each block, in tqdm units. Defaults to 1.
+            total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+        """
+        if total_size is not None:
+            self.total = total_size  # pylint: disable=attribute-defined-outside-init
+        self.update(blocks * block_size - self.n)
+
+
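
urlretrieve reports progress as a cumulative block count, while tqdm's update() expects an
increment; update_to bridges the two by subtracting the bytes already counted (self.n). A
standalone sketch of the same hook, assuming any reachable URL (the one below is a placeholder):

    from urllib.request import urlretrieve

    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        # urlretrieve invokes the hook as reporthook(blocks, block_size, total_size).
        urlretrieve("https://example.com/file.bin", "file.bin", reporthook=t.update_to)
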
+def download_url(url: str, filename: str) -> None:
+    """Downloads a file from url to filename, with a progress bar."""
+    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # nosec
+
+
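
With the progress-bar plumbing wrapped up, a download becomes a one-liner; the URL and
destination here are hypothetical:

    download_url("https://example.com/dataset.zip", "data/raw/dataset.zip")
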
+def _download_raw_dataset(metadata: Dict) -> None:
+    if os.path.exists(metadata["filename"]):
+        return
+    logger.info(f"Downloading raw dataset from {metadata['url']}...")
+    download_url(metadata["url"], metadata["filename"])
+    logger.info("Computing SHA-256...")
+    sha256 = compute_sha256(metadata["filename"])
+    if sha256 != metadata["sha256"]:
+        raise ValueError(
+            "Downloaded data file SHA-256 does not match that listed in metadata document."
+        )
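
The private downloader reads three keys from its metadata dict: "url", "filename", and
"sha256". A sketch of such a dict, with invented placeholder values (a real metadata document
would carry the dataset's actual URL and checksum, so the final verification passes):

    # All values below are placeholders for illustration only.
    metadata = {
        "url": "https://example.com/dataset.zip",
        "filename": "data/raw/dataset.zip",
        "sha256": "0" * 64,
    }
    _download_raw_dataset(metadata)  # returns early if the file already exists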