summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
commit75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree6521cc4134459e42591b2375f70acd348741474e /text_recognizer/data/emnist.py
parente5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r--text_recognizer/data/emnist.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 2d0ac29..c6be123 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,12 +3,12 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple
import zipfile
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import toml
import torchvision.transforms as T
@@ -50,8 +50,7 @@ class EMNIST(BaseDataModule):
transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
def __attrs_post_init__(self) -> None:
- self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
- self.dims = (1, *input_shape)
+ self.dims = (1, *self.mapping.input_size)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
@@ -106,7 +105,7 @@ class EMNIST(BaseDataModule):
def emnist_mapping(
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Set[str]] = None,
) -> Tuple[List, Dict[str, int], List[int]]:
"""Return the EMNIST mapping."""
if not ESSENTIALS_FILENAME.exists():
@@ -130,7 +129,7 @@ def download_and_process_emnist() -> None:
def _process_raw_dataset(filename: str, dirname: Path) -> None:
"""Processes the raw EMNIST dataset."""
- logger.info("Unzipping EMNIST...")
+ log.info("Unzipping EMNIST...")
curdir = os.getcwd()
os.chdir(dirname)
content = zipfile.ZipFile(filename, "r")
@@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
from scipy.io import loadmat
- logger.info("Loading training data from .mat file")
+ log.info("Loading training data from .mat file")
data = loadmat("matlab/emnist-byclass.mat")
x_train = (
data["dataset"]["train"][0, 0]["images"][0, 0]
@@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
if SAMPLE_TO_BALANCE:
- logger.info("Balancing classes to reduce amount of data")
+ log.info("Balancing classes to reduce amount of data")
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
- logger.info("Saving to HDF5 in a compressed format...")
+ log.info("Saving to HDF5 in a compressed format...")
PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
@@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
- logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+ log.info("Saving essential dataset parameters to text_recognizer/datasets...")
mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
characters = _augment_emnist_characters(mapping.values())
essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
@@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
with ESSENTIALS_FILENAME.open(mode="w") as f:
json.dump(essentials, f)
- logger.info("Cleaning up...")
+ log.info("Cleaning up...")
shutil.rmtree("matlab")
os.chdir(curdir)