summary refs log tree commit diff
path: root/src/text_recognizer/datasets/emnist_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py  129
1 file changed, 23 insertions, 106 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 1c6e959..d64a991 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
from text_recognizer.datasets import (
- _augment_emnist_mapping,
- _load_emnist_essentials,
DATA_DIRNAME,
EmnistDataset,
+ EmnistMapper,
ESSENTIALS_FILENAME,
)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
@@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset):
def __init__(
self,
train: bool = False,
- emnist: Optional[EmnistDataset] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
max_length: int = 34,
@@ -39,10 +37,9 @@ class EmnistLinesDataset(Dataset):
num_samples: int = 10000,
seed: int = 4711,
) -> None:
- """Short summary.
+ """Set attributes and loads the dataset.
Args:
- emnist (EmnistDataset): A EmnistDataset object.
train (bool): Flag for the filename. Defaults to False. Defaults to None.
transform (Optional[Callable]): The transform of the data. Defaults to None.
target_transform (Optional[Callable]): The transform of the target. Defaults to None.
@@ -54,7 +51,6 @@ class EmnistLinesDataset(Dataset):
"""
self.train = train
- self.emnist = emnist
self.transform = transform
if self.transform is None:
@@ -64,11 +60,10 @@ class EmnistLinesDataset(Dataset):
if self.target_transform is None:
self.target_transform = torch.tensor
- # Load emnist dataset infromation.
- essentials = _load_emnist_essentials()
- self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
- self.num_classes = len(self.mapping)
- self.input_shape = essentials["input_shape"]
+ # Extract dataset information.
+ self._mapper = EmnistMapper()
+ self.input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
self.max_length = max_length
self.min_overlap = min_overlap
@@ -81,10 +76,13 @@ class EmnistLinesDataset(Dataset):
self.output_shape = (self.max_length, self.num_classes)
self.seed = seed
- # Placeholders for the generated dataset.
+ # Placeholders for the dataset.
self.data = None
self.target = None
+ # Load dataset.
+ self._load_or_generate_data()
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -104,7 +102,6 @@ class EmnistLinesDataset(Dataset):
if torch.is_tensor(index):
index = index.tolist()
- # data = np.array([self.data[index]])
data = self.data[index]
targets = self.targets[index]
@@ -116,6 +113,11 @@ class EmnistLinesDataset(Dataset):
return data, targets
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EmnistLinesDataset"
+
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
@@ -130,6 +132,11 @@ class EmnistLinesDataset(Dataset):
)
@property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -161,9 +168,10 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- self.emnist.load_emnist_dataset()
+ emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+
samples_by_character = get_samples_by_character(
- self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping,
+ emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
)
DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
@@ -332,94 +340,3 @@ def create_datasets(
num_samples=num,
)
emnist_lines._load_or_generate_data()
-
-
-class EmnistLinesDataLoaders:
- """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset."""
-
- def __init__(
- self,
- splits: List[str],
- max_length: int = 34,
- min_overlap: float = 0,
- max_overlap: float = 0.33,
- num_samples: int = 10000,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
- seed: int = 4711,
- ) -> None:
- """Sets the data loader arguments."""
- self.splits = splits
- self.dataset_args = {
- "max_length": max_length,
- "min_overlap": min_overlap,
- "max_overlap": max_overlap,
- "num_samples": num_samples,
- "transform": transform,
- "target_transform": target_transform,
- "seed": seed,
- }
- self.batch_size = batch_size
- self.shuffle = shuffle
- self.num_workers = num_workers
- self.cuda = cuda
- self._data_loaders = self._load_data_loaders()
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return self._data_loaders[self.splits[0]].dataset.__repr__()
-
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistLines"
-
- def __call__(self, split: str) -> DataLoader:
- """Returns the `split` DataLoader.
-
- Args:
- split (str): The dataset split, i.e. train or val.
-
- Returns:
- DataLoader: A PyTorch DataLoader.
-
- Raises:
- ValueError: If the split does not exist.
-
- """
- try:
- return self._data_loaders[split]
- except KeyError:
- raise ValueError(f"Split {split} does not exist.")
-
- def _load_data_loaders(self) -> Dict[str, DataLoader]:
- """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders."""
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in self.splits:
-
- if split == "train":
- self.dataset_args["train"] = True
- else:
- self.dataset_args["train"] = False
-
- emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args)
-
- emnist_lines_dataset._load_or_generate_data()
-
- data_loader = DataLoader(
- dataset=emnist_lines_dataset,
- batch_size=self.batch_size,
- shuffle=self.shuffle,
- num_workers=self.num_workers,
- pin_memory=self.cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders