diff options
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 2d0ac29..c6be123 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,12 +3,12 @@ import json import os from pathlib import Path import shutil -from typing import Callable, Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple import zipfile import attr import h5py -from loguru import logger +from loguru import logger as log import numpy as np import toml import torchvision.transforms as T @@ -50,8 +50,7 @@ class EMNIST(BaseDataModule): transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) def __attrs_post_init__(self) -> None: - self.mapping, self.inverse_mapping, input_shape = emnist_mapping() - self.dims = (1, *input_shape) + self.dims = (1, *self.mapping.input_size) def prepare_data(self) -> None: """Downloads dataset if not present.""" @@ -106,7 +105,7 @@ class EMNIST(BaseDataModule): def emnist_mapping( - extra_symbols: Optional[Sequence[str]] = None, + extra_symbols: Optional[Set[str]] = None, ) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" if not ESSENTIALS_FILENAME.exists(): @@ -130,7 +129,7 @@ def download_and_process_emnist() -> None: def _process_raw_dataset(filename: str, dirname: Path) -> None: """Processes the raw EMNIST dataset.""" - logger.info("Unzipping EMNIST...") + log.info("Unzipping EMNIST...") curdir = os.getcwd() os.chdir(dirname) content = zipfile.ZipFile(filename, "r") @@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: from scipy.io import loadmat - logger.info("Loading training data from .mat file") + log.info("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = ( data["dataset"]["train"][0, 0]["images"][0, 0] @@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS if SAMPLE_TO_BALANCE: - logger.info("Balancing classes to reduce amount of data") + log.info("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) - logger.info("Saving to HDF5 in a compressed format...") + log.info("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") @@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") - logger.info("Saving essential dataset parameters to text_recognizer/datasets...") + log.info("Saving essential dataset parameters to text_recognizer/datasets...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(mapping.values()) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} @@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: with ESSENTIALS_FILENAME.open(mode="w") as f: json.dump(essentials, f) - logger.info("Cleaning up...") + log.info("Cleaning up...") shutil.rmtree("matlab") os.chdir(curdir) |