From a3a40c9c0118039460d5c9fba6a74edc0cdba106 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sat, 20 Mar 2021 22:52:10 +0100
Subject: add base dataset class

---
 text_recognizer/datasets/base_dataset.py | 70 ++++++++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 text_recognizer/datasets/base_dataset.py

diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
new file mode 100644
index 0000000..7322d7f
--- /dev/null
+++ b/text_recognizer/datasets/base_dataset.py
@@ -0,0 +1,70 @@
+"""Base PyTorch Dataset class."""
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+
+
+class BaseDataset(Dataset):
+    """
+    Base Dataset class that processes data and targets through optional transforms.
+
+    Args:
+        data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images.
+        targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays.
+        transform (Callable): Function that takes a datum and applies transforms.
+        target_transform (Callable): Function that takes a target and applies
+            target transforms.
+    """
+    def __init__(self,
+        data: Union[Sequence, Tensor],
+        targets: Union[Sequence, Tensor],
+        transform: Callable = None,
+        target_transform: Callable = None,
+    ) -> None:
+        if len(data) != len(targets):
+            raise ValueError("Data and targets must be of equal length.")
+        self.data = data
+        self.targets = targets
+        self.transform = transform
+        self.target_transform = target_transform
+
+
+    def __len__(self) -> int:
+        """Return the length of the dataset."""
+        return len(self.data)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """Return a datum and its target, after processing by transforms.
+
+        Args:
+            index (int): Index of a datum in the dataset.
+
+        Returns:
+            Tuple[Any, Any]: Datum and target pair.
+
+        """
+        datum, target = self.data[index], self.targets[index]
+
+        if self.transform is not None:
+            datum = self.transform(datum)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return datum, target
+
+
+def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor:
+    """
+    Convert a sequence of N strings to a (N, length) Tensor, with each string wrapped with <S> and <E> tokens,
+    and padded with the <P> token.
+    """
+    labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"]
+    for i, string in enumerate(strings):
+        tokens = list(string)
+        tokens = ["<S>", *tokens, "<E>"]
+        for j, token in enumerate(tokens):
+            labels[i, j] = mapping[token]
+    return labels
--
cgit v1.2.3-70-g09d2