Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 129
1 file changed, 23 insertions(+), 106 deletions(-)
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 1c6e959..d64a991 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset
 from torchvision.transforms import Compose, Normalize, ToTensor
 
 from text_recognizer.datasets import (
-    _augment_emnist_mapping,
-    _load_emnist_essentials,
     DATA_DIRNAME,
     EmnistDataset,
+    EmnistMapper,
     ESSENTIALS_FILENAME,
 )
 from text_recognizer.datasets.sentence_generator import SentenceGenerator
@@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset):
     def __init__(
         self,
         train: bool = False,
-        emnist: Optional[EmnistDataset] = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         max_length: int = 34,
@@ -39,10 +37,9 @@
         num_samples: int = 10000,
         seed: int = 4711,
     ) -> None:
-        """Short summary.
+        """Set attributes and loads the dataset.
 
         Args:
-            emnist (EmnistDataset): A EmnistDataset object.
             train (bool): Flag for the filename. Defaults to False. Defaults to None.
             transform (Optional[Callable]): The transform of the data. Defaults to None.
             target_transform (Optional[Callable]): The transform of the target. Defaults to None.
@@ -54,7 +51,6 @@
 
         """
         self.train = train
-        self.emnist = emnist
 
         self.transform = transform
         if self.transform is None:
@@ -64,11 +60,10 @@
         if self.target_transform is None:
             self.target_transform = torch.tensor
 
-        # Load emnist dataset infromation.
-        essentials = _load_emnist_essentials()
-        self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
-        self.num_classes = len(self.mapping)
-        self.input_shape = essentials["input_shape"]
+        # Extract dataset information.
+        self._mapper = EmnistMapper()
+        self.input_shape = self._mapper.input_shape
+        self.num_classes = self._mapper.num_classes
 
         self.max_length = max_length
         self.min_overlap = min_overlap
@@ -81,10 +76,13 @@
         self.output_shape = (self.max_length, self.num_classes)
         self.seed = seed
 
-        # Placeholders for the generated dataset.
+        # Placeholders for the dataset.
         self.data = None
         self.target = None
 
+        # Load dataset.
+        self._load_or_generate_data()
+
     def __len__(self) -> int:
         """Returns the length of the dataset."""
         return len(self.data)
@@ -104,7 +102,6 @@
         if torch.is_tensor(index):
             index = index.tolist()
 
-        # data = np.array([self.data[index]])
         data = self.data[index]
         targets = self.targets[index]
 
@@ -116,6 +113,11 @@
 
         return data, targets
 
+    @property
+    def __name__(self) -> str:
+        """Returns the name of the dataset."""
+        return "EmnistLinesDataset"
+
     def __repr__(self) -> str:
         """Returns information about the dataset."""
         return (
@@ -130,6 +132,11 @@
         )
 
     @property
+    def mapper(self) -> EmnistMapper:
+        """Returns the EmnistMapper."""
+        return self._mapper
+
+    @property
     def data_filename(self) -> Path:
         """Path to the h5 file."""
         filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -161,9 +168,10 @@
         sentence_generator = SentenceGenerator(self.max_length)
 
         # Load emnist dataset.
-        self.emnist.load_emnist_dataset()
+        emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+
         samples_by_character = get_samples_by_character(
-            self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping,
+            emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
         )
 
         DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
@@ -332,94 +340,3 @@ def create_datasets(
             num_samples=num,
         )
         emnist_lines._load_or_generate_data()
-
-
-class EmnistLinesDataLoaders:
-    """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset."""
-
-    def __init__(
-        self,
-        splits: List[str],
-        max_length: int = 34,
-        min_overlap: float = 0,
-        max_overlap: float = 0.33,
-        num_samples: int = 10000,
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-        batch_size: int = 128,
-        shuffle: bool = False,
-        num_workers: int = 0,
-        cuda: bool = True,
-        seed: int = 4711,
-    ) -> None:
-        """Sets the data loader arguments."""
-        self.splits = splits
-        self.dataset_args = {
-            "max_length": max_length,
-            "min_overlap": min_overlap,
-            "max_overlap": max_overlap,
-            "num_samples": num_samples,
-            "transform": transform,
-            "target_transform": target_transform,
-            "seed": seed,
-        }
-        self.batch_size = batch_size
-        self.shuffle = shuffle
-        self.num_workers = num_workers
-        self.cuda = cuda
-        self._data_loaders = self._load_data_loaders()
-
-    def __repr__(self) -> str:
-        """Returns information about the dataset."""
-        return self._data_loaders[self.splits[0]].dataset.__repr__()
-
-    @property
-    def __name__(self) -> str:
-        """Returns the name of the dataset."""
-        return "EmnistLines"
-
-    def __call__(self, split: str) -> DataLoader:
-        """Returns the `split` DataLoader.
-
-        Args:
-            split (str): The dataset split, i.e. train or val.
-
-        Returns:
-            DataLoader: A PyTorch DataLoader.
-
-        Raises:
-            ValueError: If the split does not exist.
-
-        """
-        try:
-            return self._data_loaders[split]
-        except KeyError:
-            raise ValueError(f"Split {split} does not exist.")
-
-    def _load_data_loaders(self) -> Dict[str, DataLoader]:
-        """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders."""
-        data_loaders = {}
-
-        for split in ["train", "val"]:
-            if split in self.splits:
-
-                if split == "train":
-                    self.dataset_args["train"] = True
-                else:
-                    self.dataset_args["train"] = False
-
-                emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args)
-
-                emnist_lines_dataset._load_or_generate_data()
-
-                data_loader = DataLoader(
-                    dataset=emnist_lines_dataset,
-                    batch_size=self.batch_size,
-                    shuffle=self.shuffle,
-                    num_workers=self.num_workers,
-                    pin_memory=self.cuda,
-                )
-
-                data_loaders[split] = data_loader
-
-        return data_loaders
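With EmnistLinesDataLoaders removed and EmnistLinesDataset now building its EmnistMapper and calling _load_or_generate_data() inside __init__, callers are left to wrap the dataset in a plain torch DataLoader themselves. Below is a minimal sketch of that usage, assuming only the constructor signature and mapper property shown in the hunks above; the batch size, shuffle, and worker settings are illustrative and not taken from this commit.

from torch.utils.data import DataLoader

from text_recognizer.datasets.emnist_lines_dataset import EmnistLinesDataset

# Shared dataset arguments from the new __init__ signature; values here are illustrative.
dataset_args = {"max_length": 34, "min_overlap": 0, "max_overlap": 0.33, "num_samples": 10000}

# Data is loaded or generated inside __init__, so no explicit load call is needed.
train_dataset = EmnistLinesDataset(train=True, **dataset_args)
val_dataset = EmnistLinesDataset(train=False, **dataset_args)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=True)

# The EmnistMapper that replaced the old essentials helpers is exposed as a property.
print(train_dataset.mapper.mapping)

Training code that previously requested loaders via the removed EmnistLinesDataLoaders(splits=[...]) wrapper would follow this pattern instead.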