summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/util.py
blob: dd16bedc075e26e767488bcb964e369398625788 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Util functions for datasets."""
import hashlib
import importlib
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Type, Union
from urllib.request import urlopen, urlretrieve

import cv2
from loguru import logger
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class Transpose:
    """Transposes the EMNIST image to the correct orientation."""

    def __call__(self, image: Image) -> np.ndarray:
        """Swaps axis."""
        return np.array(image).swapaxes(0, 1)


def compute_sha256(filename: Union[Path, str]) -> str:
    """Returns the SHA256 checksum of a file."""
    with open(filename, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()


class TqdmUpTo(tqdm):
    """TQDM progress bar when downloading files.

    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py

    """

    def update_to(
        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
    ) -> None:
        """Updates the progress bar.

        Args:
            blocks (int): Number of blocks transferred so far. Defaults to 1.
            block_size (int): Size of each block, in tqdm units. Defaults to 1.
            total_size (Optional[int]): Total size in tqdm units. Defaults to None.
        """
        if total_size is not None:
            self.total = total_size  # pylint: disable=attribute-defined-outside-init
        self.update(blocks * block_size - self.n)


def download_url(url: str, filename: str) -> None:
    """Downloads a file from url to filename, with a progress bar."""
    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # nosec


def _download_raw_dataset(metadata: Dict) -> None:
    if os.path.exists(metadata["filename"]):
        return
    logger.info(f"Downloading raw dataset from {metadata['url']}...")
    download_url(metadata["url"], metadata["filename"])
    logger.info("Computing SHA-256...")
    sha256 = compute_sha256(metadata["filename"])
    if sha256 != metadata["sha256"]:
        raise ValueError(
            "Downloaded data file SHA-256 does not match that listed in metadata document."
        )