summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/emnist.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-10-10 18:03:11 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-10-10 18:03:11 +0200
commit: 30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch)
tree: 08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/emnist.py
parent: ad3f404d36a9add32992698dd083d368f3b96812 (diff)
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- text_recognizer/data/emnist.py | 16
1 file changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 9ec6efe..e2bc5b9 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,7 +3,7 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Dict, List, Optional, Sequence, Set, Tuple
import zipfile
import attr
@@ -11,14 +11,14 @@ import h5py
from loguru import logger as log
import numpy as np
import toml
-import torchvision.transforms as T
from text_recognizer.data.base_data_module import (
BaseDataModule,
load_and_print_info,
)
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
-from text_recognizer.data.download_utils import download_dataset
+from text_recognizer.data.utils.download_utils import download_dataset
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
SEED = 4711
@@ -30,7 +30,9 @@ METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+ESSENTIALS_FILENAME = (
+ Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json"
+)
@attr.s(auto_attribs=True)
@@ -46,9 +48,6 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- train_fraction: float = attr.ib(default=0.8)
- transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
-
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self.dims = (1, *self.mapping.input_size)
@@ -226,4 +225,5 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
def download_emnist() -> None:
"""Download dataset from internet, if it does not exists, and displays info."""
- load_and_print_info(EMNIST)
+ transform = load_transform_from_file("transform/default.yaml")
+ load_and_print_info(EMNIST(transform=transform, test_transfrom=transform))