path: root/text_recognizer/data/emnist.py
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-09-27 00:08:04 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-09-27 00:08:04 +0200
commit 27ff7d113108e9cc51ddc5ff13b648b9c75fa865 (patch)
tree 96b35c2f65978b8718665aaded3d29f00aaf43e2 /text_recognizer/data/emnist.py
parent 3227735099f8acb37ffe658b8f04b6c308b64d23 (diff)
Add metadata
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- text_recognizer/data/emnist.py | 48
1 file changed, 19 insertions(+), 29 deletions(-)
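
The diff below swaps the module-level constants in emnist.py for references to a new text_recognizer.metadata.emnist module. For orientation, here is a minimal sketch of what that module plausibly contains, reconstructed from the constants the diff removes; the internal layout of the metadata package is an assumption (the real module may derive the base data directory from a shared helper rather than BaseDataModule, e.g. to avoid import cycles).

# text_recognizer/metadata/emnist.py -- hypothetical reconstruction from the
# constants removed in the diff below; not the commit's verbatim contents.
from pathlib import Path

from text_recognizer.data.base_data_module import BaseDataModule  # assumes no circular import

SEED = 4711
NUM_SPECIAL_TOKENS = 4
SAMPLE_TO_BALANCE = True

RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
# The original path was relative to text_recognizer/data/; if this constant now
# lives under text_recognizer/metadata/, it must be rewritten relative to the
# new location -- this adjusted form is an assumption.
ESSENTIALS_FILENAME = (
    Path(__file__).parents[1].resolve() / "data" / "mappings" / "emnist_essentials.json"
)

Centralizing seeds, flags, and paths in one importable module lets other data modules and networks share them without importing the heavier dataset classes themselves.
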
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 72cc80a..9c5727f 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -15,27 +15,15 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.download_utils import download_dataset
-
-SEED = 4711
-NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
-
-RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = (
- Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json"
-)
+from text_recognizer.metadata import emnist as metadata
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
'The EMNIST dataset is a set of handwritten character digits derived from the NIST
- Special Database 19 and converted to a 28x28 pixel image format and dataset structure
- that directly matches the MNIST dataset.'
+ Special Database 19 and converted to a 28x28 pixel image format and dataset
+ structure that directly matches the MNIST dataset.'
From https://www.nist.gov/itl/iad/image-group/emnist-dataset
The data split we will use is
@@ -48,13 +36,13 @@ class EMNIST(BaseDataModule):
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
- if not PROCESSED_DATA_FILENAME.exists():
+ if not metadata.PROCESSED_DATA_FILENAME.exists():
download_and_process_emnist()
def setup(self, stage: Optional[str] = None) -> None:
"""Loads the dataset specified by the stage."""
if stage == "fit" or stage is None:
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f:
self.x_train = f["x_train"][:]
self.y_train = f["y_train"][:].squeeze().astype(int)
@@ -62,11 +50,11 @@ class EMNIST(BaseDataModule):
self.x_train, self.y_train, transform=self.transform
)
self.data_train, self.data_val = split_dataset(
- dataset_train, fraction=self.train_fraction, seed=SEED
+ dataset_train, fraction=self.train_fraction, seed=metadata.SEED
)
if stage == "test" or stage is None:
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f:
self.x_test = f["x_test"][:]
self.y_test = f["y_test"][:].squeeze().astype(int)
self.data_test = BaseDataset(
@@ -100,9 +88,9 @@ class EMNIST(BaseDataModule):
def download_and_process_emnist() -> None:
"""Downloads and preprocesses EMNIST dataset."""
- metadata = toml.load(METADATA_FILENAME)
- download_dataset(metadata, DL_DATA_DIRNAME)
- _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
+ metadata_ = toml.load(metadata.METADATA_FILENAME)
+ download_dataset(metadata_, metadata.DL_DATA_DIRNAME)
+ _process_raw_dataset(metadata_["filename"], metadata.DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path) -> None:
@@ -122,20 +110,22 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
.reshape(-1, 28, 28)
.swapaxes(1, 2)
)
- y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+ y_train = (
+ data["dataset"]["train"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS
+ )
x_test = (
data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
)
- y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+ y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS
- if SAMPLE_TO_BALANCE:
+ if metadata.SAMPLE_TO_BALANCE:
log.info("Balancing classes to reduce amount of data")
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
log.info("Saving to HDF5 in a compressed format...")
- PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
+ metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(metadata.PROCESSED_DATA_FILENAME, "w") as f:
f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
@@ -146,7 +136,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
characters = _augment_emnist_characters(mapping.values())
essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
- with ESSENTIALS_FILENAME.open(mode="w") as f:
+ with metadata.ESSENTIALS_FILENAME.open(mode="w") as f:
json.dump(essentials, f)
log.info("Cleaning up...")
@@ -156,7 +146,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Balances the dataset by taking the mean number of instances per class."""
- np.random.seed(SEED)
+ np.random.seed(metadata.SEED)
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
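
The diff is cut off inside _sample_to_balance, so only the seeding, the mean class count, and the start of the per-label loop are visible. Below is a self-contained sketch of the balancing technique the function implements, assuming the loop samples num_to_sample indices per class with replacement and deduplicates them; everything past the visible lines is an illustration, not the commit's verbatim code.

from typing import Tuple

import numpy as np

SEED = 4711  # value of metadata.SEED, taken from the constants removed above


def sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Subsample each class down to roughly the mean number of instances per class."""
    np.random.seed(SEED)
    num_to_sample = int(np.bincount(y.flatten()).mean())
    all_sampled_indices = []
    for label in np.unique(y.flatten()):
        indices = np.where(y.flatten() == label)[0]
        # Assumption: sample with replacement, then deduplicate, so classes
        # smaller than num_to_sample keep at most their original size.
        sampled = np.unique(np.random.choice(indices, num_to_sample))
        all_sampled_indices.append(sampled)
    indices = np.concatenate(all_sampled_indices)
    return x[indices], y[indices]

Because the seed is fixed inside the function, the subsample is deterministic across runs, which keeps the processed HDF5 file reproducible.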