diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
commit | 46a1472d33d3a4180798492e819f2ec02bc3b1a3 (patch) | |
tree | 22322ed0d8f9f803966ea745ec5bb8c759f8db64 /text_recognizer/data/emnist.py | |
parent | 8248f173132dfb7e47ec62b08e9235990c8626e3 (diff) |
Add refactor of iam lines
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 37 |
1 file changed, 20 insertions, 17 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 7f67893..3e10b5f 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -1,6 +1,6 @@ """EMNIST dataset: downloads it from FSDL aws url if not present.""" from pathlib import Path -from typing import Sequence, Tuple +from typing import Dict, List, Sequence, Tuple import json import os import shutil @@ -10,11 +10,9 @@ import h5py import numpy as np from loguru import logger import toml -import torch -from torch.utils.data import random_split from torchvision import transforms -from text_recognizer.data.base_dataset import BaseDataset +from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, @@ -48,23 +46,18 @@ class EMNIST(BaseDataModule): self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 ) -> None: super().__init__(batch_size, num_workers) - if not ESSENTIALS_FILENAME.exists(): - _download_and_process_emnist() - with ESSENTIALS_FILENAME.open() as f: - essentials = json.load(f) self.train_fraction = train_fraction - self.mapping = list(essentials["characters"]) - self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} + self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping() self.data_train = None self.data_val = None self.data_test = None self.transform = transforms.Compose([transforms.ToTensor()]) - self.dims = (1, *essentials["input_shape"]) + self.dims = (1, * self.input_shape) self.output_dims = (1,) def prepare_data(self) -> None: if not PROCESSED_DATA_FILENAME.exists(): - _download_and_process_emnist() + download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: @@ -75,10 +68,8 @@ class EMNIST(BaseDataModule): dataset_train = BaseDataset( self.x_train, self.y_train, transform=self.transform ) - train_size = int(self.train_fraction * 
len(dataset_train)) - val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split( - dataset_train, [train_size, val_size], generator=torch.Generator() + self.data_train, self.data_val = split_dataset( + dataset_train, fraction=self.train_fraction, seed=SEED ) if stage == "test" or stage is None: @@ -104,7 +95,19 @@ class EMNIST(BaseDataModule): return basic + data -def _download_and_process_emnist() -> None: +def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]: + """Return the EMNIST mapping.""" + if not ESSENTIALS_FILENAME.exists(): + download_and_process_emnist() + with ESSENTIALS_FILENAME.open() as f: + essentials = json.load(f) + mapping = list(essentials["characters"]) + inverse_mapping = {v: k for k, v in enumerate(mapping)} + input_shape = essentials["input_shape"] + return mapping, inverse_mapping, input_shape + + +def download_and_process_emnist() -> None: metadata = toml.load(METADATA_FILENAME) download_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) |