summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 22:52:10 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 22:52:10 +0100
commita3a40c9c0118039460d5c9fba6a74edc0cdba106 (patch)
tree8d7a421e444302f2fea0a420220dac85dedd2135 /text_recognizer/datasets
parent7e8e54e84c63171e748bbf09516fd517e6821ace (diff)
add base dataset class
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r--text_recognizer/datasets/base_dataset.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
new file mode 100644
index 0000000..7322d7f
--- /dev/null
+++ b/text_recognizer/datasets/base_dataset.py
@@ -0,0 +1,70 @@
+"""Base PyTorch Dataset class."""
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+
+
+class BaseDataset(Dataset):
+ """
+ Base Dataset class that processes data and targets through optional transfroms.
+
+ Args:
+ data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images.
+ targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays.
+ tranform (Callable): Function that takes a datum and applies transforms.
+ target_transform (Callable): Fucntion that takes a target and applies
+ target transforms.
+ """
+ def __init__(self,
+ data: Union[Sequence, Tensor],
+ targets: Union[Sequence, Tensor],
+ transform: Callable = None,
+ target_transform: Callable = None,
+ ) -> None:
+ if len(data) != len(targets):
+ raise ValueError("Data and targets must be of equal length.")
+ self.data = data
+ self.targets = targets
+ self.transform = transform
+ self.target_transform = target_transform
+
+
+ def __len__(self) -> int:
+ """Return the length of the dataset."""
+ return len(self.data)
+
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ """Return a datum and its target, after processing by transforms.
+
+ Args:
+ index (int): Index of a datum in the dataset.
+
+ Returns:
+ Tuple[Any, Any]: Datum and target pair.
+
+ """
+ datum, target = self.data[index], self.targets[index]
+
+ if self.transform is not None:
+ datum = self.transform(datum)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return datum, target
+
+
+def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor:
+ """
+ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens,
+ and padded wiht the <P> token.
+ """
+ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"]
+ for i, string in enumerate(strings):
+ tokens = list(string)
+ tokens = ["<S>", *tokens, "</S>"]
+ for j, token in enumerate(tokens):
+ labels[i, j] = mapping[token]
+ return labels