diff options
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 9ec6efe..e2bc5b9 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,7 +3,7 @@ import json import os from pathlib import Path import shutil -from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import Dict, List, Optional, Sequence, Set, Tuple import zipfile import attr @@ -11,14 +11,14 @@ import h5py from loguru import logger as log import numpy as np import toml -import torchvision.transforms as T from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, ) from text_recognizer.data.base_dataset import BaseDataset, split_dataset -from text_recognizer.data.download_utils import download_dataset +from text_recognizer.data.utils.download_utils import download_dataset +from text_recognizer.data.transforms.load_transform import load_transform_from_file SEED = 4711 @@ -30,7 +30,9 @@ METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json" +) @attr.s(auto_attribs=True) @@ -46,9 +48,6 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - train_fraction: float = attr.ib(default=0.8) - transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) - def __attrs_post_init__(self) -> None: """Post init configuration.""" self.dims = (1, *self.mapping.input_size) @@ -226,4 +225,5 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: def download_emnist() -> None: """Download dataset from internet, if it does not exists, and displays info.""" - load_and_print_info(EMNIST) + transform = load_transform_from_file("transform/default.yaml") + load_and_print_info(EMNIST(transform=transform, test_transfrom=transform)) |