diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/emnist.py | 40 |
2 files changed, 29 insertions, 12 deletions
diff --git a/text_recognizer/__init__.py b/text_recognizer/__init__.py index e69de29..20c123d 100644 --- a/text_recognizer/__init__.py +++ b/text_recognizer/__init__.py @@ -0,0 +1 @@ +"""Text recognizer project.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 12adaab..bf3faec 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -1,22 +1,22 @@ """EMNIST dataset: downloads it from FSDL aws url if not present.""" -from pathlib import Path -from typing import Dict, List, Optional, Sequence, Tuple import json import os +from pathlib import Path import shutil +from typing import Dict, List, Optional, Sequence, Tuple import zipfile import h5py -import numpy as np from loguru import logger +import numpy as np import toml from torchvision import transforms -from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, ) +from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.download_utils import download_dataset @@ -33,9 +33,11 @@ ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.j class EMNIST(BaseDataModule): - """ - "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 - and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." + """Lightning DataModule class for loading EMNIST dataset. + + 'The EMNIST dataset is a set of handwritten character digits derived from the NIST + Special Database 19 and converted to a 28x28 pixel image format and dataset structure + that directly matches the MNIST dataset.' From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is @@ -56,10 +58,12 @@ class EMNIST(BaseDataModule): self.output_dims = (1,) def prepare_data(self) -> None: + """Downloads dataset if not present.""" if not PROCESSED_DATA_FILENAME.exists(): download_and_process_emnist() def setup(self, stage: str = None) -> None: + """Loads the dataset specified by the stage.""" if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_train = f["x_train"][:] @@ -81,22 +85,32 @@ class EMNIST(BaseDataModule): ) def __repr__(self) -> str: - basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" + """Returns string with info about the dataset.""" + basic = ( + "EMNIST Dataset\n" + f"Num classes: {len(self.mapping)}\n" + f"Mapping: {self.mapping}\n" + f"Dims: {self.dims}\n" + ) if not any([self.data_train, self.data_val, self.data_test]): return basic datum, target = next(iter(self.train_dataloader())) data = ( - f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" - f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n" - f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n" + "Train/val/test sizes: " + f"{len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + "Batch x stats: " + f"{(datum.shape, datum.dtype, datum.min())}" + f"{(datum.mean(), datum.std(), datum.max())}\n" + f"Batch y stats: " + f"{(target.shape, target.dtype, target.min(), target.max())}\n" ) return basic + data def emnist_mapping( - extra_symbols: Optional[Sequence[str]], + extra_symbols: Optional[Sequence[str]] = None, ) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" if not ESSENTIALS_FILENAME.exists(): @@ -112,12 +126,14 @@ def emnist_mapping( def download_and_process_emnist() -> None: + """Downloads and preprocesses EMNIST dataset.""" metadata = toml.load(METADATA_FILENAME) download_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path) -> None: + """Processes the raw EMNIST dataset.""" logger.info("Unzipping EMNIST...") curdir = os.getcwd() os.chdir(dirname) |