summary | refs | log | tree | commit | diff
path: root/text_recognizer/datasets/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/datasets/emnist.py')
-rw-r--r-- text_recognizer/datasets/emnist.py 12
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
index 7c208c4..66101b5 100644
--- a/text_recognizer/datasets/emnist.py
+++ b/text_recognizer/datasets/emnist.py
@@ -70,9 +70,11 @@ class EMNIST(BaseDataModule):
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
self.x_train = f["x_train"][:]
- self.y_train = f["y_train"][:]
+ self.y_train = f["y_train"][:].squeeze().astype(int)
- dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform)
+ dataset_train = BaseDataset(
+ self.x_train, self.y_train, transform=self.transform
+ )
train_size = int(self.train_fraction * len(dataset_train))
val_size = len(dataset_train) - train_size
self.data_train, self.data_val = random_split(
@@ -82,8 +84,10 @@ class EMNIST(BaseDataModule):
if stage == "test" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
self.x_test = f["x_test"][:]
- self.y_test = f["y_test"][:]
- self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)
+ self.y_test = f["y_test"][:].squeeze().astype(int)
+ self.data_test = BaseDataset(
+ self.x_test, self.y_test, transform=self.transform
+ )
def __repr__(self) -> str:
basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"