summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/base_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- text_recognizer/data/base_dataset.py 13
1 files changed, 13 insertions, 0 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index a9e9c24..d00daaf 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -71,3 +71,16 @@ def convert_strings_to_labels(
for j, token in enumerate(tokens):
labels[i, j] = mapping[token]
return labels
+
+
+def split_dataset(
+    dataset: BaseDataset, fraction: float, seed: int
+) -> Tuple[BaseDataset, BaseDataset]:
+    """Randomly split a dataset into two non-overlapping parts.
+
+    The first part holds int(fraction * len(dataset)) samples and the second
+    part holds the remaining samples.
+
+    Args:
+        dataset: Dataset to split; only its length is inspected here.
+        fraction: Share of samples for the first part, in the range [0.0, 1.0).
+        seed: Seed for the random generator, so the same split is reproduced
+            for the same seed.
+
+    Returns:
+        The two parts produced by torch.utils.data.random_split.
+
+    Raises:
+        ValueError: If fraction is negative, NaN, or greater than or equal to 1.0.
+    """
+    # A negative fraction gives a negative first length whose sum with the
+    # second length still equals len(dataset), so it must be rejected here.
+    if not 0.0 <= fraction < 1.0:
+        raise ValueError("Fraction must be in the range [0.0, 1.0).")
+    split_1 = int(fraction * len(dataset))
+    split_2 = len(dataset) - split_1
+    return torch.utils.data.random_split(
+        dataset, [split_1, split_2], generator=torch.Generator().manual_seed(seed)
+    )