summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/emnist.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-23 21:55:42 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-23 21:55:42 +0100
commitae589fb3ffdbf6c4bb1ae35345f7a3665deeebc5 (patch)
tree1702f74c069679ebdd74a03892275c6eb3a80ffd /text_recognizer/datasets/emnist.py
parente3741de333a3a43a7968241b6eccaaac66dd7b20 (diff)
refactored emnist lines dataset
Diffstat (limited to 'text_recognizer/datasets/emnist.py')
-rw-r--r--text_recognizer/datasets/emnist.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
index 7c208c4..66101b5 100644
--- a/text_recognizer/datasets/emnist.py
+++ b/text_recognizer/datasets/emnist.py
@@ -70,9 +70,11 @@ class EMNIST(BaseDataModule):
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
self.x_train = f["x_train"][:]
- self.y_train = f["y_train"][:]
+ self.y_train = f["y_train"][:].squeeze().astype(int)
- dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform)
+ dataset_train = BaseDataset(
+ self.x_train, self.y_train, transform=self.transform
+ )
train_size = int(self.train_fraction * len(dataset_train))
val_size = len(dataset_train) - train_size
self.data_train, self.data_val = random_split(
@@ -82,8 +84,10 @@ class EMNIST(BaseDataModule):
if stage == "test" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
self.x_test = f["x_test"][:]
- self.y_test = f["y_test"][:]
- self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)
+ self.y_test = f["y_test"][:].squeeze().astype(int)
+ self.data_test = BaseDataset(
+ self.x_test, self.y_test, transform=self.transform
+ )
def __repr__(self) -> str:
basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"