diff options
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
| -rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 129 | 
1 files changed, 23 insertions, 106 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 1c6e959..d64a991 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset  from torchvision.transforms import Compose, Normalize, ToTensor  from text_recognizer.datasets import ( -    _augment_emnist_mapping, -    _load_emnist_essentials,      DATA_DIRNAME,      EmnistDataset, +    EmnistMapper,      ESSENTIALS_FILENAME,  )  from text_recognizer.datasets.sentence_generator import SentenceGenerator @@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset):      def __init__(          self,          train: bool = False, -        emnist: Optional[EmnistDataset] = None,          transform: Optional[Callable] = None,          target_transform: Optional[Callable] = None,          max_length: int = 34, @@ -39,10 +37,9 @@ class EmnistLinesDataset(Dataset):          num_samples: int = 10000,          seed: int = 4711,      ) -> None: -        """Short summary. +        """Set attributes and loads the dataset.          Args: -            emnist (EmnistDataset): A EmnistDataset object.              train (bool): Flag for the filename. Defaults to False. Defaults to None.              transform (Optional[Callable]): The transform of the data. Defaults to None.              target_transform (Optional[Callable]): The transform of the target. Defaults to None. @@ -54,7 +51,6 @@ class EmnistLinesDataset(Dataset):          """          self.train = train -        self.emnist = emnist          self.transform = transform          if self.transform is None: @@ -64,11 +60,10 @@ class EmnistLinesDataset(Dataset):          if self.target_transform is None:              self.target_transform = torch.tensor -        # Load emnist dataset infromation. -        essentials = _load_emnist_essentials() -        self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) -        self.num_classes = len(self.mapping) -        self.input_shape = essentials["input_shape"] +        # Extract dataset information. +        self._mapper = EmnistMapper() +        self.input_shape = self._mapper.input_shape +        self.num_classes = self._mapper.num_classes          self.max_length = max_length          self.min_overlap = min_overlap @@ -81,10 +76,13 @@ class EmnistLinesDataset(Dataset):          self.output_shape = (self.max_length, self.num_classes)          self.seed = seed -        # Placeholders for the generated dataset. +        # Placeholders for the dataset.          self.data = None          self.target = None +        # Load dataset. +        self._load_or_generate_data() +      def __len__(self) -> int:          """Returns the length of the dataset."""          return len(self.data) @@ -104,7 +102,6 @@ class EmnistLinesDataset(Dataset):          if torch.is_tensor(index):              index = index.tolist() -        # data = np.array([self.data[index]])          data = self.data[index]          targets = self.targets[index] @@ -116,6 +113,11 @@ class EmnistLinesDataset(Dataset):          return data, targets +    @property +    def __name__(self) -> str: +        """Returns the name of the dataset.""" +        return "EmnistLinesDataset" +      def __repr__(self) -> str:          """Returns information about the dataset."""          return ( @@ -130,6 +132,11 @@ class EmnistLinesDataset(Dataset):          )      @property +    def mapper(self) -> EmnistMapper: +        """Returns the EmnistMapper.""" +        return self._mapper + +    @property      def data_filename(self) -> Path:          """Path to the h5 file."""          filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" @@ -161,9 +168,10 @@ class EmnistLinesDataset(Dataset):          sentence_generator = SentenceGenerator(self.max_length)          # Load emnist dataset. -        self.emnist.load_emnist_dataset() +        emnist = EmnistDataset(train=self.train, sample_to_balance=True) +          samples_by_character = get_samples_by_character( -            self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping, +            emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,          )          DATA_DIRNAME.mkdir(parents=True, exist_ok=True) @@ -332,94 +340,3 @@ def create_datasets(              num_samples=num,          )          emnist_lines._load_or_generate_data() - - -class EmnistLinesDataLoaders: -    """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset.""" - -    def __init__( -        self, -        splits: List[str], -        max_length: int = 34, -        min_overlap: float = 0, -        max_overlap: float = 0.33, -        num_samples: int = 10000, -        transform: Optional[Callable] = None, -        target_transform: Optional[Callable] = None, -        batch_size: int = 128, -        shuffle: bool = False, -        num_workers: int = 0, -        cuda: bool = True, -        seed: int = 4711, -    ) -> None: -        """Sets the data loader arguments.""" -        self.splits = splits -        self.dataset_args = { -            "max_length": max_length, -            "min_overlap": min_overlap, -            "max_overlap": max_overlap, -            "num_samples": num_samples, -            "transform": transform, -            "target_transform": target_transform, -            "seed": seed, -        } -        self.batch_size = batch_size -        self.shuffle = shuffle -        self.num_workers = num_workers -        self.cuda = cuda -        self._data_loaders = self._load_data_loaders() - -    def __repr__(self) -> str: -        """Returns information about the dataset.""" -        return self._data_loaders[self.splits[0]].dataset.__repr__() - -    @property -    def __name__(self) -> str: -        """Returns the name of the dataset.""" -        return "EmnistLines" - -    def __call__(self, split: str) -> DataLoader: -        """Returns the `split` DataLoader. - -        Args: -            split (str): The dataset split, i.e. train or val. - -        Returns: -            DataLoader: A PyTorch DataLoader. - -        Raises: -            ValueError: If the split does not exist. - -        """ -        try: -            return self._data_loaders[split] -        except KeyError: -            raise ValueError(f"Split {split} does not exist.") - -    def _load_data_loaders(self) -> Dict[str, DataLoader]: -        """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders.""" -        data_loaders = {} - -        for split in ["train", "val"]: -            if split in self.splits: - -                if split == "train": -                    self.dataset_args["train"] = True -                else: -                    self.dataset_args["train"] = False - -                emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args) - -                emnist_lines_dataset._load_or_generate_data() - -                data_loader = DataLoader( -                    dataset=emnist_lines_dataset, -                    batch_size=self.batch_size, -                    shuffle=self.shuffle, -                    num_workers=self.num_workers, -                    pin_memory=self.cuda, -                ) - -                data_loaders[split] = data_loader - -        return data_loaders  |