diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 |
commit | eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch) | |
tree | 0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/data/emnist.py | |
parent | 4d1f2cef39688871d2caafce42a09316381a27ae (diff) |
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 22 |
1 files changed, 9 insertions, 13 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 824b947..d51a42a 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,9 +3,10 @@ import json import os from pathlib import Path import shutil -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Sequence, Tuple import zipfile +import attr import h5py from loguru import logger import numpy as np @@ -32,6 +33,7 @@ PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" +@attr.s(auto_attribs=True) class EMNIST(BaseDataModule): """Lightning DataModule class for loading EMNIST dataset. @@ -44,18 +46,12 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - def __init__( - self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 - ) -> None: - super().__init__(batch_size, num_workers) - self.train_fraction = train_fraction - self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping() - self.data_train = None - self.data_val = None - self.data_test = None - self.transform = T.Compose([T.ToTensor()]) - self.dims = (1, *self.input_shape) - self.output_dims = (1,) + train_fraction: float = attr.ib() + transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) + + def __attrs_post_init__(self) -> None: + self.mapping, self.inverse_mapping, input_shape = emnist_mapping() + self.dims = (1, *input_shape) def prepare_data(self) -> None: """Downloads dataset if not present.""" |