Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py  |  56
1 file changed, 13 insertions(+), 43 deletions(-)
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 656131a..8fa77cd 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -9,17 +9,16 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
-from text_recognizer.datasets import (
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+from text_recognizer.datasets.util import (
DATA_DIRNAME,
- EmnistDataset,
EmnistMapper,
ESSENTIALS_FILENAME,
)
-from text_recognizer.datasets.sentence_generator import SentenceGenerator
-from text_recognizer.datasets.util import Transpose
from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
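
The import hunk swaps torch's Dataset base for a project-level text_recognizer.datasets.dataset.Dataset, and pulls EmnistDataset, Transpose, SentenceGenerator, EmnistMapper, and the path constants from their specific modules. Judging by what this commit deletes from EmnistLinesDataset further down (the inline transform defaults, the EmnistMapper instance, and the input_shape, mapper, mapping, and __len__ members), the new base class presumably looks roughly like the sketch below; the actual dataset.py is not part of this diff and may differ.

    # Hedged reconstruction of text_recognizer/datasets/dataset.py, inferred from
    # the members this commit removes from EmnistLinesDataset; not copied from the repo.
    from typing import Callable, Dict, Optional, Tuple

    import torch
    from torch.utils.data import Dataset as TorchDataset
    from torchvision.transforms import ToTensor

    from text_recognizer.datasets.util import EmnistMapper


    class Dataset(TorchDataset):
        """Shared behaviour for the datasets in this package."""

        def __init__(
            self,
            train: bool = False,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
        ) -> None:
            self.train = train
            # Defaults that previously lived inline in EmnistLinesDataset.__init__.
            self.transform = transform if transform is not None else ToTensor()
            self.target_transform = (
                target_transform if target_transform is not None else torch.tensor
            )
            self._mapper = EmnistMapper()
            self._input_shape = self._mapper.input_shape
            self._output_shape = None
            # Subclasses fill these in load_or_generate_data().
            self._data = None
            self._targets = None

        @property
        def input_shape(self) -> Tuple:
            """Input shape of the data."""
            return self._input_shape

        @property
        def output_shape(self) -> Tuple:
            """Output shape of the targets."""
            return self._output_shape

        @property
        def mapper(self) -> EmnistMapper:
            """Returns the EmnistMapper."""
            return self._mapper

        @property
        def mapping(self) -> Dict:
            """Return EMNIST mapping from index to character."""
            return self._mapper.mapping

        def __len__(self) -> int:
            """Returns the length of the dataset."""
            return len(self._data) if self._data is not None else 0
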
@@ -52,18 +51,11 @@ class EmnistLinesDataset(Dataset):
seed (int): Seed number. Defaults to 4711.
"""
- self.train = train
-
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
+ super().__init__(
+ train=train, transform=transform, target_transform=target_transform,
+ )
# Extract dataset information.
- self._mapper = EmnistMapper()
self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
@@ -75,24 +67,12 @@ class EmnistLinesDataset(Dataset):
self.input_shape[0],
self.input_shape[1] * self.max_length,
)
- self.output_shape = (self.max_length, self.num_classes)
+ self._output_shape = (self.max_length, self.num_classes)
self.seed = seed
# Placeholders for the dataset.
- self.data = None
- self.target = None
-
- # Load dataset.
- self._load_or_generate_data()
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ self._data = None
+ self._target = None
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
@@ -132,16 +112,6 @@ class EmnistLinesDataset(Dataset):
)
@property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
-
- @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -151,7 +121,7 @@ class EmnistLinesDataset(Dataset):
filename = "test_" + filename
return DATA_DIRNAME / filename
- def _load_or_generate_data(self) -> None:
+ def load_or_generate_data(self) -> None:
"""Loads the dataset, if it does not exist a new dataset is generated before loading it."""
np.random.seed(self.seed)
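
Making load_or_generate_data public, together with dropping the load call from __init__ in the earlier hunk, moves data loading out of construction: callers now build the dataset and trigger loading explicitly. A minimal usage sketch, assuming the remaining constructor arguments keep their defaults:

    # Usage sketch; constructor arguments other than `train` are left at their defaults.
    from text_recognizer.datasets.emnist_lines_dataset import EmnistLinesDataset

    dataset = EmnistLinesDataset(train=True)
    dataset.load_or_generate_data()  # previously ran inside __init__
    image, target = dataset[0]
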
@@ -163,8 +133,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self.data = f["data"][:]
- self.targets = f["targets"][:]
+ self._data = f["data"][:]
+ self._targets = f["targets"][:]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""