summaryrefslogtreecommitdiff
path: root/src/text_recognizer/tests/support
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/tests/support')
-rw-r--r--src/text_recognizer/tests/support/create_emnist_support_files.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py
index 5dd1a81..c04860d 100644
--- a/src/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/src/text_recognizer/tests/support/create_emnist_support_files.py
@@ -2,10 +2,8 @@
from pathlib import Path
import shutil
-from text_recognizer.datasets.emnist_dataset import (
- fetch_emnist_dataset,
- load_emnist_mapping,
-)
+from text_recognizer.datasets.emnist_dataset import EmnistDataset
+from text_recognizer.datasets.util import EmnistMapper
from text_recognizer.util import write_image
SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
@@ -16,15 +14,16 @@ def create_emnist_support_files() -> None:
shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
SUPPORT_DIRNAME.mkdir()
- dataset = fetch_emnist_dataset(split="byclass", train=False)
- mapping = load_emnist_mapping()
+ dataset = EmnistDataset(train=False)
+ dataset.load_or_generate_data()
+ mapping = EmnistMapper()
for index in [5, 7, 9]:
image, label = dataset[index]
if len(image.shape) == 3:
image = image.squeeze(0)
image = image.numpy()
- label = mapping[int(label)]
+ label = mapping(int(label))
print(index, label)
write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))