author    | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:08:04 +0200
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:08:04 +0200
commit    | 27ff7d113108e9cc51ddc5ff13b648b9c75fa865 (patch)
tree      | 96b35c2f65978b8718665aaded3d29f00aaf43e2 /text_recognizer/data/emnist.py
parent    | 3227735099f8acb37ffe658b8f04b6c308b64d23 (diff)
Add metadata
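The commit replaces the module-level constants in emnist.py with a new text_recognizer/metadata/emnist.py module, imported as metadata. The new module itself is outside this diff; a minimal sketch of what it could contain, assuming it simply collects the constants removed below, might look like this (the data-directory helper and the essentials path are guesses, since the old values relied on BaseDataModule.data_dirname() and on emnist.py's own location):

"""Hypothetical sketch of text_recognizer/metadata/emnist.py (not in this diff)."""
from pathlib import Path

# Assumption: the repo keeps its data under <repo>/data, matching what
# BaseDataModule.data_dirname() returned before this change.
DATA_DIRNAME = Path(__file__).resolve().parents[2] / "data"

SEED = 4711
NUM_SPECIAL_TOKENS = 4
SAMPLE_TO_BALANCE = True

RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"

# The old path resolved relative to text_recognizer/data/, so the parent hop
# below is an assumption about where the mappings directory lives now.
ESSENTIALS_FILENAME = (
    Path(__file__).resolve().parents[1] / "data" / "mappings" / "emnist_essentials.json"
)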
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 48
1 file changed, 19 insertions, 29 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 72cc80a..9c5727f 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -15,27 +15,15 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
 from text_recognizer.data.base_dataset import BaseDataset, split_dataset
 from text_recognizer.data.transforms.load_transform import load_transform_from_file
 from text_recognizer.data.utils.download_utils import download_dataset
-
-SEED = 4711
-NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
-
-RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = (
-    Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json"
-)
+from text_recognizer.metadata import emnist as metadata


 class EMNIST(BaseDataModule):
     """Lightning DataModule class for loading EMNIST dataset.

     'The EMNIST dataset is a set of handwritten character digits derived from the NIST
-    Special Database 19 and converted to a 28x28 pixel image format and dataset structure
-    that directly matches the MNIST dataset.'
+    Special Database 19 and converted to a 28x28 pixel image format and dataset
+    structure that directly matches the MNIST dataset.'

     From https://www.nist.gov/itl/iad/image-group/emnist-dataset

     The data split we will use is
@@ -48,13 +36,13 @@ class EMNIST(BaseDataModule):

     def prepare_data(self) -> None:
         """Downloads dataset if not present."""
-        if not PROCESSED_DATA_FILENAME.exists():
+        if not metadata.PROCESSED_DATA_FILENAME.exists():
             download_and_process_emnist()

     def setup(self, stage: Optional[str] = None) -> None:
         """Loads the dataset specified by the stage."""
         if stage == "fit" or stage is None:
-            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+            with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f:
                 self.x_train = f["x_train"][:]
                 self.y_train = f["y_train"][:].squeeze().astype(int)

@@ -62,11 +50,11 @@ class EMNIST(BaseDataModule):
                 self.x_train, self.y_train, transform=self.transform
             )
             self.data_train, self.data_val = split_dataset(
-                dataset_train, fraction=self.train_fraction, seed=SEED
+                dataset_train, fraction=self.train_fraction, seed=metadata.SEED
             )

         if stage == "test" or stage is None:
-            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+            with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f:
                 self.x_test = f["x_test"][:]
                 self.y_test = f["y_test"][:].squeeze().astype(int)
             self.data_test = BaseDataset(
@@ -100,9 +88,9 @@ class EMNIST(BaseDataModule):

 def download_and_process_emnist() -> None:
     """Downloads and preprocesses EMNIST dataset."""
-    metadata = toml.load(METADATA_FILENAME)
-    download_dataset(metadata, DL_DATA_DIRNAME)
-    _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
+    metadata_ = toml.load(metadata.METADATA_FILENAME)
+    download_dataset(metadata_, metadata.DL_DATA_DIRNAME)
+    _process_raw_dataset(metadata_["filename"], metadata.DL_DATA_DIRNAME)


 def _process_raw_dataset(filename: str, dirname: Path) -> None:
@@ -122,20 +110,22 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
         .reshape(-1, 28, 28)
         .swapaxes(1, 2)
     )
-    y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+    y_train = (
+        data["dataset"]["train"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS
+    )
     x_test = (
         data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
     )
-    y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+    y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS

-    if SAMPLE_TO_BALANCE:
+    if metadata.SAMPLE_TO_BALANCE:
         log.info("Balancing classes to reduce amount of data")
         x_train, y_train = _sample_to_balance(x_train, y_train)
         x_test, y_test = _sample_to_balance(x_test, y_test)

     log.info("Saving to HDF5 in a compressed format...")
-    PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-    with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
+    metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+    with h5py.File(metadata.PROCESSED_DATA_FILENAME, "w") as f:
         f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
         f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
         f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
@@ -146,7 +136,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
     characters = _augment_emnist_characters(mapping.values())
     essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}

-    with ESSENTIALS_FILENAME.open(mode="w") as f:
+    with metadata.ESSENTIALS_FILENAME.open(mode="w") as f:
         json.dump(essentials, f)

     log.info("Cleaning up...")
@@ -156,7 +146,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:

 def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
     """Balances the dataset by taking the mean number of instances per class."""
-    np.random.seed(SEED)
+    np.random.seed(metadata.SEED)
     num_to_sample = int(np.bincount(y.flatten()).mean())
     all_sampled_indices = []
     for label in np.unique(y.flatten()):
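The final hunk cuts off inside _sample_to_balance, so only the head of the balancing routine is visible above. As a reference for the technique it applies, here is a self-contained sketch that subsamples every class down to roughly the mean class count; the loop body after `for label in ...` is an assumption filled in for illustration, not taken from this diff:

import numpy as np
from typing import Tuple


def sample_to_balance(
    x: np.ndarray, y: np.ndarray, seed: int = 4711
) -> Tuple[np.ndarray, np.ndarray]:
    """Subsample each class to at most the mean number of instances per class."""
    np.random.seed(seed)
    num_to_sample = int(np.bincount(y.flatten()).mean())
    all_sampled_indices = []
    for label in np.unique(y.flatten()):
        indices = np.where(y == label)[0]
        # Sampling with replacement and then deduplicating keeps small classes
        # nearly intact while capping large classes near num_to_sample.
        sampled = np.unique(np.random.choice(indices, num_to_sample))
        all_sampled_indices.append(sampled)
    selected = np.concatenate(all_sampled_indices)
    return x[selected], y[selected]

Seeding np.random with metadata.SEED before sampling makes the subsampling reproducible, which matters here because the balanced arrays are written once to HDF5 and reused by every later run.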