summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/iam_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/iam_lines_dataset.py')
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py68
1 files changed, 21 insertions, 47 deletions
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 477f500..4a74b2b 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -5,11 +5,15 @@ import h5py
from loguru import logger
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
-from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
-from text_recognizer.datasets.util import compute_sha256, download_url
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
@@ -29,47 +33,26 @@ class IamLinesDataset(Dataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
- self.train = train
- self.split = "train" if self.train else "test"
- self._mapper = EmnistMapper()
- self.num_classes = self.mapper.num_classes
-
- # Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
-
- self.subsample_fraction = subsample_fraction
- self.data = None
- self.targets = None
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
@property
def input_shape(self) -> Tuple:
"""Input shape of the data."""
- return self.data.shape[1:]
+ return self.data.shape[1:] if self.data is not None else None
@property
def output_shape(self) -> Tuple:
"""Output shape of the data."""
- return self.targets.shape[1:] + (self.num_classes,)
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ return (
+ self.targets.shape[1:] + (self.num_classes,)
+ if self.targets is not None
+ else None
+ )
def load_or_generate_data(self) -> None:
"""Load or generate dataset data."""
@@ -78,19 +61,10 @@ class IamLinesDataset(Dataset):
logger.info("Downloading IAM lines...")
download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self.data = f[f"x_{self.split}"][:]
- self.targets = f[f"y_{self.split}"][:]
+ self._data = f[f"x_{self.split}"][:]
+ self._targets = f[f"y_{self.split}"][:]
self._subsample()
- def _subsample(self) -> None:
- """Only a fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
-
- num_samples = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_samples]
- self.targets = self.targets[:num_samples]
-
def __repr__(self) -> str:
"""Print info about the dataset."""
return (