diff options
Diffstat (limited to 'src/text_recognizer/tests/support')
-rw-r--r-- | src/text_recognizer/tests/support/create_emnist_support_files.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py index 5dd1a81..c04860d 100644 --- a/src/text_recognizer/tests/support/create_emnist_support_files.py +++ b/src/text_recognizer/tests/support/create_emnist_support_files.py @@ -2,10 +2,8 @@ from pathlib import Path import shutil -from text_recognizer.datasets.emnist_dataset import ( - fetch_emnist_dataset, - load_emnist_mapping, -) +from text_recognizer.datasets.emnist_dataset import EmnistDataset +from text_recognizer.datasets.util import EmnistMapper from text_recognizer.util import write_image SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" @@ -16,15 +14,16 @@ def create_emnist_support_files() -> None: shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) SUPPORT_DIRNAME.mkdir() - dataset = fetch_emnist_dataset(split="byclass", train=False) - mapping = load_emnist_mapping() + dataset = EmnistDataset(train=False) + dataset.load_or_generate_data() + mapping = EmnistMapper() for index in [5, 7, 9]: image, label = dataset[index] if len(image.shape) == 3: image = image.squeeze(0) image = image.numpy() - label = mapping[int(label)] + label = mapping(int(label)) print(index, label) write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) |