Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/emnist.py  40
1 file changed, 28 insertions(+), 12 deletions(-)
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 12adaab..bf3faec 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,22 +1,22 @@
"""EMNIST dataset: downloads it from FSDL aws url if not present."""
-from pathlib import Path
-from typing import Dict, List, Optional, Sequence, Tuple
import json
import os
+from pathlib import Path
import shutil
+from typing import Dict, List, Optional, Sequence, Tuple
import zipfile
import h5py
-import numpy as np
from loguru import logger
+import numpy as np
import toml
from torchvision import transforms
-from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.base_data_module import (
BaseDataModule,
load_and_print_info,
)
+from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.download_utils import download_dataset
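The hunk above alphabetizes the imports so that plain `import` statements and `from ... import` statements sort together by module name (json, os, pathlib, shutil, typing, zipfile). This ordering is what isort produces with force_sort_within_sections enabled; a minimal pyproject.toml sketch, assuming isort is the tool in use (the repo's actual formatter config is not shown in this diff):

    [tool.isort]
    # sort "import x" and "from x import y" together within each section,
    # which yields the interleaved ordering seen in the hunk above
    force_sort_within_sections = true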
@@ -33,9 +33,11 @@ ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.j
class EMNIST(BaseDataModule):
- """
- "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
- and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
+ """Lightning DataModule class for loading EMNIST dataset.
+
+ 'The EMNIST dataset is a set of handwritten character digits derived from the NIST
+ Special Database 19 and converted to a 28x28 pixel image format and dataset structure
+ that directly matches the MNIST dataset.'
From https://www.nist.gov/itl/iad/image-group/emnist-dataset
The data split we will use is
@@ -56,10 +58,12 @@ class EMNIST(BaseDataModule):
self.output_dims = (1,)
def prepare_data(self) -> None:
+ """Downloads dataset if not present."""
if not PROCESSED_DATA_FILENAME.exists():
download_and_process_emnist()
def setup(self, stage: str = None) -> None:
+ """Loads the dataset specified by the stage."""
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
self.x_train = f["x_train"][:]
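The two newly documented hooks follow the LightningDataModule lifecycle: prepare_data runs once to download and process the raw files, while setup loads the arrays needed for the requested stage. A standalone-usage sketch, offered only as an illustration (inside a Trainer, Lightning calls these hooks itself, and the constructor arguments are assumed to have defaults):

    from text_recognizer.data.emnist import EMNIST

    datamodule = EMNIST()       # assumes the module is default-constructible
    datamodule.prepare_data()   # downloads and processes the HDF5 file if missing
    datamodule.setup("fit")     # loads x_train / y_train from PROCESSED_DATA_FILENAME
    print(datamodule)           # __repr__ (changed below) reports mapping and batch stats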
@@ -81,22 +85,32 @@ class EMNIST(BaseDataModule):
)
def __repr__(self) -> str:
- basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
+ """Returns string with info about the dataset."""
+ basic = (
+ "EMNIST Dataset\n"
+ f"Num classes: {len(self.mapping)}\n"
+ f"Mapping: {self.mapping}\n"
+ f"Dims: {self.dims}\n"
+ )
if not any([self.data_train, self.data_val, self.data_test]):
return basic
datum, target = next(iter(self.train_dataloader()))
data = (
- f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
- f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n"
- f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n"
+ "Train/val/test sizes: "
+ f"{len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ "Batch x stats: "
+ f"{(datum.shape, datum.dtype, datum.min())}"
+ f"{(datum.mean(), datum.std(), datum.max())}\n"
+ f"Batch y stats: "
+ f"{(target.shape, target.dtype, target.min(), target.max())}\n"
)
return basic + data
def emnist_mapping(
- extra_symbols: Optional[Sequence[str]],
+ extra_symbols: Optional[Sequence[str]] = None,
) -> Tuple[List, Dict[str, int], List[int]]:
"""Return the EMNIST mapping."""
if not ESSENTIALS_FILENAME.exists():
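With extra_symbols now defaulting to None, existing call sites can drop the explicit argument. A sketch of both call styles; the names for the returned triple are assumptions read off the annotated Tuple[List, Dict[str, int], List[int]]:

    mapping, inverse_mapping, input_dims = emnist_mapping()        # plain EMNIST mapping
    mapping, inverse_mapping, input_dims = emnist_mapping(["\n"])  # with an extra newline symbol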
@@ -112,12 +126,14 @@ def emnist_mapping(
def download_and_process_emnist() -> None:
+ """Downloads and preprocesses EMNIST dataset."""
metadata = toml.load(METADATA_FILENAME)
download_dataset(metadata, DL_DATA_DIRNAME)
_process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path) -> None:
+ """Processes the raw EMNIST dataset."""
logger.info("Unzipping EMNIST...")
curdir = os.getcwd()
os.chdir(dirname)
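_process_raw_dataset saves curdir and chdirs into the download directory before unzipping, presumably restoring the old directory once processing finishes (outside this hunk). An equivalent pattern, sketched here as an alternative rather than the repo's actual code, wraps the restore in a context manager so it also runs if extraction raises:

    import contextlib
    import os
    from pathlib import Path

    @contextlib.contextmanager
    def working_directory(path: Path):
        """Temporarily chdir into path, restoring the previous cwd on exit."""
        previous = os.getcwd()
        os.chdir(path)
        try:
            yield
        finally:
            os.chdir(previous)

    # usage sketch:
    #     with working_directory(dirname):
    #         zipfile.ZipFile(filename).extractall()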