diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-23 21:55:42 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-23 21:55:42 +0100 |
commit | ae589fb3ffdbf6c4bb1ae35345f7a3665deeebc5 (patch) | |
tree | 1702f74c069679ebdd74a03892275c6eb3a80ffd /text_recognizer/datasets/emnist.py | |
parent | e3741de333a3a43a7968241b6eccaaac66dd7b20 (diff) |
refactored emnist lines dataset
Diffstat (limited to 'text_recognizer/datasets/emnist.py')
-rw-r--r-- | text_recognizer/datasets/emnist.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py index 7c208c4..66101b5 100644 --- a/text_recognizer/datasets/emnist.py +++ b/text_recognizer/datasets/emnist.py @@ -70,9 +70,11 @@ class EMNIST(BaseDataModule): if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_train = f["x_train"][:] - self.y_train = f["y_train"][:] + self.y_train = f["y_train"][:].squeeze().astype(int) - dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform) + dataset_train = BaseDataset( + self.x_train, self.y_train, transform=self.transform + ) train_size = int(self.train_fraction * len(dataset_train)) val_size = len(dataset_train) - train_size self.data_train, self.data_val = random_split( @@ -82,8 +84,10 @@ class EMNIST(BaseDataModule): if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] - self.y_test = f["y_test"][:] - self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) + self.y_test = f["y_test"][:].squeeze().astype(int) + self.data_test = BaseDataset( + self.x_test, self.y_test, transform=self.transform + ) def __repr__(self) -> str: basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" |