summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/iam_lines_dataset.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/datasets/iam_lines_dataset.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/datasets/iam_lines_dataset.py')
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py110
1 files changed, 0 insertions, 110 deletions
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
deleted file mode 100644
index 1cb84bd..0000000
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""IamLinesDataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import h5py
-from loguru import logger
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
-PROCESSED_DATA_URL = (
- "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
-)
-
-
-class IamLinesDataset(Dataset):
- """IAM lines datasets for handwritten text lines."""
-
- def __init__(
- self,
- train: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- self.pad_token = "_" if pad_token is None else pad_token
-
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- init_token=init_token,
- pad_token=pad_token,
- eos_token=eos_token,
- lower=lower,
- )
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self.data.shape[1:] if self.data is not None else None
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return (
- self.targets.shape[1:] + (self.num_classes,)
- if self.targets is not None
- else None
- )
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- if not PROCESSED_DATA_FILENAME.exists():
- PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info("Downloading IAM lines...")
- download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self._data = f[f"x_{self.split}"][:]
- self._targets = f[f"y_{self.split}"][:]
- self._subsample()
-
- def __repr__(self) -> str:
- """Print info about the dataset."""
- return (
- "IAM Lines Dataset\n" # pylint: disable=no-member
- f"Number classes: {self.num_classes}\n"
- f"Mapping: {self.mapper.mapping}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- if self.transform:
- data = self.transform(data)
-
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets