summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r--text_recognizer/data/emnist.py37
1 files changed, 20 insertions, 17 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 7f67893..3e10b5f 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,6 +1,6 @@
"""EMNIST dataset: downloads it from FSDL aws url if not present."""
from pathlib import Path
-from typing import Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
import json
import os
import shutil
@@ -10,11 +10,9 @@ import h5py
import numpy as np
from loguru import logger
import toml
-import torch
-from torch.utils.data import random_split
from torchvision import transforms
-from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.base_data_module import (
BaseDataModule,
load_and_print_info,
@@ -48,23 +46,18 @@ class EMNIST(BaseDataModule):
self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
) -> None:
super().__init__(batch_size, num_workers)
- if not ESSENTIALS_FILENAME.exists():
- _download_and_process_emnist()
- with ESSENTIALS_FILENAME.open() as f:
- essentials = json.load(f)
self.train_fraction = train_fraction
- self.mapping = list(essentials["characters"])
- self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
+ self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping()
self.data_train = None
self.data_val = None
self.data_test = None
self.transform = transforms.Compose([transforms.ToTensor()])
- self.dims = (1, *essentials["input_shape"])
+ self.dims = (1, * self.input_shape)
self.output_dims = (1,)
def prepare_data(self) -> None:
if not PROCESSED_DATA_FILENAME.exists():
- _download_and_process_emnist()
+ download_and_process_emnist()
def setup(self, stage: str = None) -> None:
if stage == "fit" or stage is None:
@@ -75,10 +68,8 @@ class EMNIST(BaseDataModule):
dataset_train = BaseDataset(
self.x_train, self.y_train, transform=self.transform
)
- train_size = int(self.train_fraction * len(dataset_train))
- val_size = len(dataset_train) - train_size
- self.data_train, self.data_val = random_split(
- dataset_train, [train_size, val_size], generator=torch.Generator()
+ self.data_train, self.data_val = split_dataset(
+ dataset_train, fraction=self.train_fraction, seed=SEED
)
if stage == "test" or stage is None:
@@ -104,7 +95,19 @@ class EMNIST(BaseDataModule):
return basic + data
-def _download_and_process_emnist() -> None:
+def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]:
+ """Return the EMNIST mapping."""
+ if not ESSENTIALS_FILENAME.exists():
+ download_and_process_emnist()
+ with ESSENTIALS_FILENAME.open() as f:
+ essentials = json.load(f)
+ mapping = list(essentials["characters"])
+ inverse_mapping = {v: k for k, v in enumerate(mapping)}
+ input_shape = essentials["input_shape"]
+ return mapping, inverse_mapping, input_shape
+
+
+def download_and_process_emnist() -> None:
metadata = toml.load(METADATA_FILENAME)
download_dataset(metadata, DL_DATA_DIRNAME)
_process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)