From f473456c19558aaf8552df97a51d4e18cc69dfa8 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 22 Jul 2020 23:18:08 +0200 Subject: Working training loop and testing of trained CharacterModel. --- src/text_recognizer/datasets/__init__.py | 2 ++ src/text_recognizer/datasets/emnist_dataset.py | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'src/text_recognizer/datasets') diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index aec5bf9..795be90 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,2 +1,4 @@ """Dataset modules.""" from .emnist_dataset import EmnistDataLoader + +__all__ = ["EmnistDataLoader"] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index a17d7a9..b92b57d 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -2,7 +2,7 @@ import json from pathlib import Path -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Type from loguru import logger import numpy as np @@ -102,21 +102,22 @@ class EmnistDataLoader: self.shuffle = shuffle self.num_workers = num_workers self.cuda = cuda + self.seed = seed self._data_loaders = self._fetch_emnist_data_loaders() @property def __name__(self) -> str: """Returns the name of the dataset.""" - return "EMNIST" + return "Emnist" - def __call__(self, split: str) -> Optional[DataLoader]: + def __call__(self, split: str) -> DataLoader: """Returns the `split` DataLoader. Args: split (str): The dataset split, i.e. train or val. Returns: - type: A PyTorch DataLoader. + DataLoader: A PyTorch DataLoader. Raises: ValueError: If the split does not exist. -- cgit v1.2.3-70-g09d2