diff options
Diffstat (limited to 'src/text_recognizer/datasets/iam_paragraphs_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/iam_paragraphs_dataset.py | 70 |
1 files changed, 14 insertions, 56 deletions
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index d65b346..4b34bd1 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -8,13 +8,17 @@ from loguru import logger
 import numpy as np
 import torch
 from torch import Tensor
-from torch.utils.data import Dataset
 from torchvision.transforms import ToTensor
 
 from text_recognizer import util
-from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.dataset import Dataset
 from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import compute_sha256, download_url
+from text_recognizer.datasets.util import (
+    compute_sha256,
+    DATA_DIRNAME,
+    download_url,
+    EmnistMapper,
+)
 
 INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
 DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
@@ -28,11 +32,7 @@ SEED = 4711
 
 
 class IamParagraphsDataset(Dataset):
-    """IAM Paragraphs dataset for paragraphs of handwritten text.
-
-    TODO: __getitem__, __len__, get_data_target_from_id
-
-    """
+    """IAM Paragraphs dataset for paragraphs of handwritten text."""
 
     def __init__(
         self,
@@ -41,34 +41,20 @@ class IamParagraphsDataset(Dataset):
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
     ) -> None:
-
+        super().__init__(
+            train=train,
+            subsample_fraction=subsample_fraction,
+            transform=transform,
+            target_transform=target_transform,
+        )
         # Load Iam dataset.
         self.iam_dataset = IamDataset()
 
-        self.train = train
-        self.split = "train" if self.train else "test"
         self.num_classes = 3
         self._input_shape = (256, 256)
         self._output_shape = self._input_shape + (self.num_classes,)
-        self.subsample_fraction = subsample_fraction
-
-        # Set transforms.
-        self.transform = transform
-        if self.transform is None:
-            self.transform = ToTensor()
-
-        self.target_transform = target_transform
-        if self.target_transform is None:
-            self.target_transform = torch.tensor
-
-        self._data = None
-        self._targets = None
         self._ids = None
 
-    def __len__(self) -> int:
-        """Returns the length of the dataset."""
-        return len(self.data)
-
     def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
         """Fetches data, target pair of the dataset for a given and index or indices.
 
@@ -94,26 +80,6 @@ class IamParagraphsDataset(Dataset):
         return data, targets
 
     @property
-    def input_shape(self) -> Tuple:
-        """Input shape of the data."""
-        return self._input_shape
-
-    @property
-    def output_shape(self) -> Tuple:
-        """Output shape of the data."""
-        return self._output_shape
-
-    @property
-    def data(self) -> Tensor:
-        """The input data."""
-        return self._data
-
-    @property
-    def targets(self) -> Tensor:
-        """The target data."""
-        return self._targets
-
-    @property
     def ids(self) -> Tensor:
         """Ids of the dataset."""
         return self._ids
@@ -201,14 +167,6 @@ class IamParagraphsDataset(Dataset):
         logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
         return crop_dims
 
-    def _subsample(self) -> None:
-        """Only this fraction of the data will be loaded."""
-        if self.subsample_fraction is None:
-            return
-        num_subsample = int(self.data.shape[0] * self.subsample_fraction)
-        self.data = self.data[:num_subsample]
-        self.targets = self.targets[:num_subsample]
-
     def __repr__(self) -> str:
         """Return info about the dataset."""
         return (