summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines_dataset.py')
-rw-r--r--text_recognizer/data/iam_lines_dataset.py110
1 files changed, 110 insertions, 0 deletions
diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py
new file mode 100644
index 0000000..1cb84bd
--- /dev/null
+++ b/text_recognizer/data/iam_lines_dataset.py
@@ -0,0 +1,110 @@
+"""IamLinesDataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import torch
+from torch import Tensor
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
+
+
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
+PROCESSED_DATA_URL = (
+ "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
+)
+
+
+class IamLinesDataset(Dataset):
+ """IAM lines datasets for handwritten text lines."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ self.pad_token = "_" if pad_token is None else pad_token
+
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ init_token=init_token,
+ pad_token=pad_token,
+ eos_token=eos_token,
+ lower=lower,
+ )
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self.data.shape[1:] if self.data is not None else None
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return (
+ self.targets.shape[1:] + (self.num_classes,)
+ if self.targets is not None
+ else None
+ )
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ if not PROCESSED_DATA_FILENAME.exists():
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info("Downloading IAM lines...")
+ download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self._data = f[f"x_{self.split}"][:]
+ self._targets = f[f"y_{self.split}"][:]
+ self._subsample()
+
+ def __repr__(self) -> str:
+ """Print info about the dataset."""
+ return (
+ "IAM Lines Dataset\n" # pylint: disable=no-member
+ f"Number classes: {self.num_classes}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets