diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100 |
commit | 8248f173132dfb7e47ec62b08e9235990c8626e3 (patch) | |
tree | 2f3ff85602cbc08b7168bf4f0d3924d32a689852 /text_recognizer/data/iam_lines_dataset.py | |
parent | 74c907a17379688967dc4b3f41a44ba83034f5e0 (diff) |
renamed datasets to data, added iam refactor
Diffstat (limited to 'text_recognizer/data/iam_lines_dataset.py')
-rw-r--r-- | text_recognizer/data/iam_lines_dataset.py | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py new file mode 100644 index 0000000..1cb84bd --- /dev/null +++ b/text_recognizer/data/iam_lines_dataset.py @@ -0,0 +1,110 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( + "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): + """IAM lines datasets for handwritten text lines.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + init_token=init_token, + pad_token=pad_token, + eos_token=eos_token, + lower=lower, + ) + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self.data.shape[1:] if self.data is not None else None + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return ( + self.targets.shape[1:] + (self.num_classes,) + if self.targets is not None + else None + ) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + if not PROCESSED_DATA_FILENAME.exists(): + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info("Downloading IAM lines...") + download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self._data = f[f"x_{self.split}"][:] + self._targets = f[f"y_{self.split}"][:] + self._subsample() + + def __repr__(self) -> str: + """Print info about the dataset.""" + return ( + "IAM Lines Dataset\n" # pylint: disable=no-member + f"Number classes: {self.num_classes}\n" + f"Mapping: {self.mapper.mapping}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets |