summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- text_recognizer/data/emnist.py | 14
1 files changed, 5 insertions, 9 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index e551543..72cc80a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,25 +1,21 @@
"""EMNIST dataset: downloads it from FSDL aws url if not present."""
import json
import os
-from pathlib import Path
import shutil
-from typing import Sequence, Tuple
import zipfile
+from pathlib import Path
+from typing import Optional, Sequence, Tuple
import h5py
-from loguru import logger as log
import numpy as np
import toml
+from loguru import logger as log
-from text_recognizer.data.base_data_module import (
- BaseDataModule,
- load_and_print_info,
-)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.download_utils import download_dataset
-
SEED = 4711
NUM_SPECIAL_TOKENS = 4
SAMPLE_TO_BALANCE = True
@@ -55,7 +51,7 @@ class EMNIST(BaseDataModule):
if not PROCESSED_DATA_FILENAME.exists():
download_and_process_emnist()
- def setup(self, stage: str = None) -> None:
+ def setup(self, stage: Optional[str] = None) -> None:
"""Loads the dataset specified by the stage."""
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: