diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
commit | 30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch) | |
tree | 08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/emnist.py | |
parent | ad3f404d36a9add32992698dd083d368f3b96812 (diff) |
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 9ec6efe..e2bc5b9 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,7 +3,7 @@ import json import os from pathlib import Path import shutil -from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import Dict, List, Optional, Sequence, Set, Tuple import zipfile import attr @@ -11,14 +11,14 @@ import h5py from loguru import logger as log import numpy as np import toml -import torchvision.transforms as T from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, ) from text_recognizer.data.base_dataset import BaseDataset, split_dataset -from text_recognizer.data.download_utils import download_dataset +from text_recognizer.data.utils.download_utils import download_dataset +from text_recognizer.data.transforms.load_transform import load_transform_from_file SEED = 4711 @@ -30,7 +30,9 @@ METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json" +) @attr.s(auto_attribs=True) @@ -46,9 +48,6 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - train_fraction: float = attr.ib(default=0.8) - transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) - def __attrs_post_init__(self) -> None: """Post init configuration.""" self.dims = (1, *self.mapping.input_size) @@ -226,4 +225,5 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: def download_emnist() -> None: """Download dataset from internet, if it does not exists, and displays info.""" - load_and_print_info(EMNIST) + transform = load_transform_from_file("transform/default.yaml") + load_and_print_info(EMNIST(transform=transform, test_transfrom=transform)) |