summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/iam_paragraphs_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/iam_paragraphs_dataset.py')
-rw-r--r--src/text_recognizer/datasets/iam_paragraphs_dataset.py70
1 files changed, 14 insertions, 56 deletions
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index d65b346..4b34bd1 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -8,13 +8,17 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from text_recognizer import util
-from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.dataset import Dataset
from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import compute_sha256, download_url
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
@@ -28,11 +32,7 @@ SEED = 4711
class IamParagraphsDataset(Dataset):
- """IAM Paragraphs dataset for paragraphs of handwritten text.
-
- TODO: __getitem__, __len__, get_data_target_from_id
-
- """
+ """IAM Paragraphs dataset for paragraphs of handwritten text."""
def __init__(
self,
@@ -41,34 +41,20 @@ class IamParagraphsDataset(Dataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
-
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
# Load Iam dataset.
self.iam_dataset = IamDataset()
- self.train = train
- self.split = "train" if self.train else "test"
self.num_classes = 3
self._input_shape = (256, 256)
self._output_shape = self._input_shape + (self.num_classes,)
- self.subsample_fraction = subsample_fraction
-
- # Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
-
- self._data = None
- self._targets = None
self._ids = None
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
-
def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
@@ -94,26 +80,6 @@ class IamParagraphsDataset(Dataset):
return data, targets
@property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return self._output_shape
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
-
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
def ids(self) -> Tensor:
"""Ids of the dataset."""
return self._ids
@@ -201,14 +167,6 @@ class IamParagraphsDataset(Dataset):
logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
return crop_dims
- def _subsample(self) -> None:
- """Only this fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
- num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_subsample]
- self.targets = self.targets[:num_subsample]
-
def __repr__(self) -> str:
"""Return info about the dataset."""
return (