1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
"""IamLinesDataset class."""
from typing import Callable, Dict, List, Optional, Tuple, Union
import h5py
from loguru import logger
import torch
from torch import Tensor
from torchvision.transforms import ToTensor
from text_recognizer.datasets.dataset import Dataset
from text_recognizer.datasets.util import (
compute_sha256,
DATA_DIRNAME,
download_url,
EmnistMapper,
)
# Remote location of the preprocessed IAM lines archive.
PROCESSED_DATA_URL = (
    "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
)

# Local directory and file where the archive is stored once downloaded.
PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
class IamLinesDataset(Dataset):
    """IAM lines dataset: handwritten text-line images paired with targets.

    The data for the selected split is read from a preprocessed HDF5 file
    (``PROCESSED_DATA_FILENAME``). That file is downloaded from
    ``PROCESSED_DATA_URL`` the first time it is needed.
    """

    def __init__(
        self,
        train: bool = False,
        subsample_fraction: Optional[float] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Forward the dataset configuration to the base ``Dataset``.

        Args:
            train (bool): Forwarded to the base class. Presumably it selects
                the training split (read back here as ``self.split``); TODO
                confirm against the base class.
            subsample_fraction (Optional[float]): Forwarded to the base class.
                Presumably it is consumed by ``self._subsample()``, with None
                meaning "no subsampling"; verify against the base class.
            transform (Optional[Callable]): Applied to each data item in
                ``__getitem__`` (read back as ``self.transform``).
            target_transform (Optional[Callable]): Applied to each target in
                ``__getitem__`` (read back as ``self.target_transform``).
        """
        super().__init__(
            train=train,
            subsample_fraction=subsample_fraction,
            transform=transform,
            target_transform=target_transform,
        )

    @property
    def input_shape(self) -> Optional[Tuple]:
        """Shape of a single data item, or None if no data is loaded."""
        return self.data.shape[1:] if self.data is not None else None

    @property
    def output_shape(self) -> Optional[Tuple]:
        """Shape of a single target with ``num_classes`` appended.

        Returns None if no targets are loaded.
        """
        return (
            self.targets.shape[1:] + (self.num_classes,)
            if self.targets is not None
            else None
        )

    def load_or_generate_data(self) -> None:
        """Load the current split from the HDF5 file, downloading it if absent.

        Reads the ``x_<split>`` dataset into ``self._data`` and the
        ``y_<split>`` dataset into ``self._targets``. It then calls
        ``self._subsample()``.
        """
        if not PROCESSED_DATA_FILENAME.exists():
            PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
            logger.info("Downloading IAM lines...")
            # NOTE(review): the downloaded file is not checksum-verified here,
            # although compute_sha256 is imported by this module. If
            # download_url can leave a partial file on failure, the exists()
            # check above would then skip re-downloading. Confirm.
            download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
        with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
            # `[:]` reads the full datasets into memory, so they remain usable
            # after the file is closed at the end of this `with` block.
            self._data = f[f"x_{self.split}"][:]
            self._targets = f[f"y_{self.split}"][:]
        self._subsample()

    def __repr__(self) -> str:
        """Summarize the dataset: class count, mapping, and array shapes.

        A shape is reported as None when its data or targets are not loaded,
        so that repr() never raises on an unloaded dataset.
        """
        # Same None guards as input_shape/output_shape. The output is
        # identical to before once the arrays are loaded.
        data_shape = self.data.shape if self.data is not None else None
        targets_shape = self.targets.shape if self.targets is not None else None
        return (
            "IAM Lines Dataset\n"  # pylint: disable=no-member
            f"Number classes: {self.num_classes}\n"
            f"Mapping: {self.mapper.mapping}\n"
            f"Data: {data_shape}\n"
            f"Targets: {targets_shape}\n"
        )

    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
        """Fetch the (data, target) pair for a given index or indices.

        Args:
            index (Union[int, Tensor]): A single int index, or a tensor of
                indices. A tensor is converted to a Python int/list before
                indexing.

        Returns:
            Tuple[Tensor, Tensor]: The data and target pair, after the
            optional ``transform`` / ``target_transform`` have been applied.
        """
        if torch.is_tensor(index):
            # Convert tensor indices to plain Python values so they can be
            # used to index self.data / self.targets.
            index = index.tolist()

        data = self.data[index]
        targets = self.targets[index]

        # Test against None explicitly rather than by truthiness, so a
        # callable that defines __len__ (e.g. a container-style transform)
        # is still applied.
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            targets = self.target_transform(targets)

        return data, targets
|