diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
commit | 46a1472d33d3a4180798492e819f2ec02bc3b1a3 (patch) | |
tree | 22322ed0d8f9f803966ea745ec5bb8c759f8db64 /text_recognizer/data/base_dataset.py | |
parent | 8248f173132dfb7e47ec62b08e9235990c8626e3 (diff) |
Add refactor of iam lines
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index a9e9c24..d00daaf 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -71,3 +71,16 @@ def convert_strings_to_labels( for j, token in enumerate(tokens): labels[i, j] = mapping[token] return labels + + +def split_dataset( + dataset: BaseDataset, fraction: float, seed: int +) -> Tuple[BaseDataset, BaseDataset]: + """Split dataset into two parts with fraction * size and (1 - fraction) * size.""" + if fraction >= 1.0: + raise ValueError("Fraction cannot be larger greater or equal to 1.0.") + split_1 = int(fraction * len(dataset)) + split_2 = len(dataset) - split_1 + return torch.utils.data.random_split( + dataset, [split_1, split_2], generator=torch.Generator().manual_seed(seed) + ) |