summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r--text_recognizer/data/emnist.py22
1 files changed, 9 insertions, 13 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 824b947..d51a42a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,9 +3,10 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
import zipfile
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -32,6 +33,7 @@ PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+@attr.s(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
@@ -44,18 +46,12 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __init__(
- self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.train_fraction = train_fraction
- self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping()
- self.data_train = None
- self.data_val = None
- self.data_test = None
- self.transform = T.Compose([T.ToTensor()])
- self.dims = (1, *self.input_shape)
- self.output_dims = (1,)
+ train_fraction: float = attr.ib()
+ transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
+
+ def __attrs_post_init__(self) -> None:
+ self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
+ self.dims = (1, *input_shape)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""