diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 22:33:58 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 22:33:58 +0100 |
commit | e3741de333a3a43a7968241b6eccaaac66dd7b20 (patch) | |
tree | 7c50aee4ca61f77e95f1b038030292c64bbb86c2 /text_recognizer/datasets/emnist.py | |
parent | aac452a2dc008338cb543549652da293c14b6b4e (diff) |
Working on EMNIST Lines dataset
Diffstat (limited to 'text_recognizer/datasets/emnist.py')
-rw-r--r-- | text_recognizer/datasets/emnist.py | 88 |
1 file changed, 50 insertions, 38 deletions
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py index e99dbfd..7c208c4 100644 --- a/text_recognizer/datasets/emnist.py +++ b/text_recognizer/datasets/emnist.py @@ -15,20 +15,23 @@ from torch.utils.data import random_split from torchvision import transforms from text_recognizer.datasets.base_dataset import BaseDataset -from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info +from text_recognizer.datasets.base_data_module import ( + BaseDataModule, + load_and_print_info, +) from text_recognizer.datasets.download_utils import download_dataset SEED = 4711 NUM_SPECIAL_TOKENS = 4 -SAMPLE_TO_BALANCE = True +SAMPLE_TO_BALANCE = True RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" -PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist" +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json" +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" class EMNIST(BaseDataModule): @@ -41,7 +44,9 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. 
""" - def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None: + def __init__( + self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 + ) -> None: super().__init__(batch_size, num_workers) if not ESSENTIALS_FILENAME.exists(): _download_and_process_emnist() @@ -64,20 +69,21 @@ class EMNIST(BaseDataModule): def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_train"][:] - targets = f["y_train"][:] - - dataset_train = BaseDataset(data, targets, transform=self.transform) + self.x_train = f["x_train"][:] + self.y_train = f["y_train"][:] + + dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform) train_size = int(self.train_fraction * len(dataset_train)) val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator()) + self.data_train, self.data_val = random_split( + dataset_train, [train_size, val_size], generator=torch.Generator() + ) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_test"][:] - targets = f["y_test"][:] - self.data_test = BaseDataset(data, targets, transform=self.transform) - + self.x_test = f["x_test"][:] + self.y_test = f["y_test"][:] + self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self) -> str: basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" @@ -111,9 +117,15 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: logger.info("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") - x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_train = ( + data["dataset"]["train"][0, 0]["images"][0, 0] + .reshape(-1, 28, 28) + .swapaxes(1, 2) + 
) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_test = ( + data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + ) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS if SAMPLE_TO_BALANCE: @@ -121,7 +133,6 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) - logger.info("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: @@ -154,7 +165,7 @@ def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda all_sampled_indices.append(sampled_indices) indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] - y_sampled= y[indices] + y_sampled = y[indices] return x_sampled, y_sampled @@ -162,24 +173,24 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset. iam_characters = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] # Also add special tokens for: # - CTC blank token at index 0 @@ -190,5 +201,6 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters] -if __name__ == "__main__": - load_print_info(EMNIST) +def download_emnist() -> None: + """Download dataset from internet, if it does not exists, and displays info.""" + load_and_print_info(EMNIST) |