summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/datasets/emnist.py')
-rw-r--r--text_recognizer/datasets/emnist.py88
1 files changed, 50 insertions, 38 deletions
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
index e99dbfd..7c208c4 100644
--- a/text_recognizer/datasets/emnist.py
+++ b/text_recognizer/datasets/emnist.py
@@ -15,20 +15,23 @@ from torch.utils.data import random_split
from torchvision import transforms
from text_recognizer.datasets.base_dataset import BaseDataset
-from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info
+from text_recognizer.datasets.base_data_module import (
+ BaseDataModule,
+ load_and_print_info,
+)
from text_recognizer.datasets.download_utils import download_dataset
SEED = 4711
NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
+SAMPLE_TO_BALANCE = True
RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist"
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json"
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
class EMNIST(BaseDataModule):
@@ -41,7 +44,9 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None:
+ def __init__(
+ self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
+ ) -> None:
super().__init__(batch_size, num_workers)
if not ESSENTIALS_FILENAME.exists():
_download_and_process_emnist()
@@ -64,20 +69,21 @@ class EMNIST(BaseDataModule):
def setup(self, stage: str = None) -> None:
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- data = f["x_train"][:]
- targets = f["y_train"][:]
-
- dataset_train = BaseDataset(data, targets, transform=self.transform)
+ self.x_train = f["x_train"][:]
+ self.y_train = f["y_train"][:]
+
+ dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform)
train_size = int(self.train_fraction * len(dataset_train))
val_size = len(dataset_train) - train_size
- self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator())
+ self.data_train, self.data_val = random_split(
+ dataset_train, [train_size, val_size], generator=torch.Generator()
+ )
if stage == "test" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- data = f["x_test"][:]
- targets = f["y_test"][:]
- self.data_test = BaseDataset(data, targets, transform=self.transform)
-
+ self.x_test = f["x_test"][:]
+ self.y_test = f["y_test"][:]
+ self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)
def __repr__(self) -> str:
basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
@@ -111,9 +117,15 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
logger.info("Loading training data from .mat file")
data = loadmat("matlab/emnist-byclass.mat")
- x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ x_train = (
+ data["dataset"]["train"][0, 0]["images"][0, 0]
+ .reshape(-1, 28, 28)
+ .swapaxes(1, 2)
+ )
y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
- x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ x_test = (
+ data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ )
y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
if SAMPLE_TO_BALANCE:
@@ -121,7 +133,6 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
-
logger.info("Saving to HDF5 in a compressed format...")
PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
@@ -154,7 +165,7 @@ def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda
all_sampled_indices.append(sampled_indices)
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
- y_sampled= y[indices]
+ y_sampled = y[indices]
return x_sampled, y_sampled
@@ -162,24 +173,24 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
"""Augment the mapping with extra symbols."""
# Extra characters from the IAM dataset.
iam_characters = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
# Also add special tokens for:
# - CTC blank token at index 0
@@ -190,5 +201,6 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
-if __name__ == "__main__":
- load_print_info(EMNIST)
+def download_emnist() -> None:
+ """Download dataset from internet, if it does not exists, and displays info."""
+ load_and_print_info(EMNIST)