"""IamLinesDataset class."""
from typing import Callable, Optional, Tuple, Union

import h5py
from loguru import logger
import torch
from torch import Tensor
from torchvision.transforms import ToTensor

from text_recognizer.datasets.dataset import Dataset
from text_recognizer.datasets.util import (
    compute_sha256,
    DATA_DIRNAME,
    download_url,
    EmnistMapper,
)


PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
PROCESSED_DATA_URL = (
    "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
)
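
# Layout note (an assumption, not stated in the original): the processed HDF5
# file is expected to contain "x_{split}" and "y_{split}" datasets, which is
# what load_or_generate_data reads below.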


class IamLinesDataset(Dataset):
    """IAM lines datasets for handwritten text lines."""

    def __init__(
        self,
        train: bool = False,
        subsample_fraction: Optional[float] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            train=train,
            subsample_fraction=subsample_fraction,
            transform=transform,
            target_transform=target_transform,
        )

    @property
    def input_shape(self) -> Optional[Tuple]:
        """Input shape of the data."""
        return self.data.shape[1:] if self.data is not None else None

    @property
    def output_shape(self) -> Optional[Tuple]:
        """Output shape of the data."""
        return (
            self.targets.shape[1:] + (self.num_classes,)
            if self.targets is not None
            else None
        )

    def load_or_generate_data(self) -> None:
        """Load the processed data, downloading it first if necessary."""
        if not PROCESSED_DATA_FILENAME.exists():
            PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
            logger.info("Downloading IAM lines...")
            download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
        # Read the images and targets for the current split into memory.
        with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
            self._data = f[f"x_{self.split}"][:]
            self._targets = f[f"y_{self.split}"][:]
        self._subsample()

    def __repr__(self) -> str:
        """Print info about the dataset."""
        return (
            "IAM Lines Dataset\n"  # pylint: disable=no-member
            f"Number classes: {self.num_classes}\n"
            f"Mapping: {self.mapper.mapping}\n"
            f"Data: {self.data.shape}\n"
            f"Targets: {self.targets.shape}\n"
        )

    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
        """Fetches data, target pair of the dataset for a given and index or indices.

        Args:
            index (Union[int, Tensor]): Either a list or int of indices/index.

        Returns:
            Tuple[Tensor, Tensor]: Data target pair.

        """
        if torch.is_tensor(index):
            index = index.tolist()

        data = self.data[index]
        targets = self.targets[index]

        if self.transform:
            data = self.transform(data)

        if self.target_transform:
            targets = self.target_transform(targets)

        return data, targets
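

# A minimal usage sketch (not part of the original module), assuming the parent
# Dataset class provides `split`, `mapper`, `num_classes`, and `_subsample`,
# which the methods above rely on. torchvision's ToTensor (imported above) is
# used here only as an illustrative transform for the raw uint8 line images.
if __name__ == "__main__":
    dataset = IamLinesDataset(train=True, transform=ToTensor())
    dataset.load_or_generate_data()
    print(dataset)
    image, target = dataset[0]
    print(image.shape, target.shape)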