summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-05 01:03:37 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-05 01:03:37 +0200
commit125d5da5fb845d03bda91426e172bca7f537584a (patch)
tree6daf305555b76338ae482e81da58aa444a8255df /src/text_recognizer
parent0d0540952f79437026fc5a146b81e4b45190ff6a (diff)
Emnist lines data loader implemented.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/datasets/__init__.py7
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py31
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py120
3 files changed, 129 insertions, 29 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index a8c46c4..1b4cc59 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,21 +1,28 @@
"""Dataset modules."""
from .emnist_dataset import (
+ _augment_emnist_mapping,
+ _load_emnist_essentials,
DATA_DIRNAME,
EmnistDataLoaders,
EmnistDataset,
+ ESSENTIALS_FILENAME,
)
from .emnist_lines_dataset import (
construct_image_from_string,
+ EmnistLinesDataLoaders,
EmnistLinesDataset,
get_samples_by_character,
)
from .util import Transpose
__all__ = [
+ "_augment_emnist_mapping",
+ "_load_emnist_essentials",
"construct_image_from_string",
"DATA_DIRNAME",
"EmnistDataset",
"EmnistDataLoaders",
+ "EmnistLinesDataLoaders",
"EmnistLinesDataset",
"get_samples_by_character",
"Transpose",
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 525df95..f3d65ee 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -260,21 +260,23 @@ class EmnistDataLoaders:
"""
self.splits = splits
- self.sample_to_balance = sample_to_balance
if subsample_fraction is not None:
if not 0.0 < subsample_fraction < 1.0:
raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
- self.transform = transform
- self.target_transform = target_transform
+ self.dataset_args = {
+ "sample_to_balance": sample_to_balance,
+ "subsample_fraction": subsample_fraction,
+ "transform": transform,
+ "target_transform": target_transform,
+ "seed": seed,
+ }
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.cuda = cuda
- self.seed = seed
- self._data_loaders = self._fetch_emnist_data_loaders()
+ self._data_loaders = self._load_data_loaders()
def __repr__(self) -> str:
"""Returns information about the dataset."""
@@ -303,7 +305,7 @@ class EmnistDataLoaders:
except KeyError:
raise ValueError(f"Split {split} does not exist.")
- def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]:
+ def _load_data_loaders(self) -> Dict[str, DataLoader]:
"""Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
data_loaders = {}
@@ -311,18 +313,11 @@ class EmnistDataLoaders:
if split in self.splits:
if split == "train":
- train = True
+ self.dataset_args["train"] = True
else:
- train = False
-
- emnist_dataset = EmnistDataset(
- train=train,
- sample_to_balance=self.sample_to_balance,
- subsample_fraction=self.subsample_fraction,
- transform=self.transform,
- target_transform=self.target_transform,
- seed=self.seed,
- )
+ self.dataset_args["train"] = False
+
+ emnist_dataset = EmnistDataset(**self.dataset_args)
emnist_dataset.load_emnist_dataset()
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 4d8b646..1c6e959 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -8,17 +8,20 @@ import h5py
from loguru import logger
import numpy as np
import torch
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
-from text_recognizer.datasets import DATA_DIRNAME, EmnistDataset
+from text_recognizer.datasets import (
+ _augment_emnist_mapping,
+ _load_emnist_essentials,
+ DATA_DIRNAME,
+ EmnistDataset,
+ ESSENTIALS_FILENAME,
+)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
from text_recognizer.datasets.util import Transpose
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
-ESSENTIALS_FILENAME = (
- Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json"
-)
class EmnistLinesDataset(Dataset):
@@ -26,8 +29,8 @@ class EmnistLinesDataset(Dataset):
def __init__(
self,
- emnist: EmnistDataset,
train: bool = False,
+ emnist: Optional[EmnistDataset] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
max_length: int = 34,
@@ -40,7 +43,7 @@ class EmnistLinesDataset(Dataset):
Args:
emnist (EmnistDataset): A EmnistDataset object.
- train (bool): Flag for the filename. Defaults to False.
+ train (bool): Flag for the filename. Defaults to False. Defaults to None.
transform (Optional[Callable]): The transform of the data. Defaults to None.
target_transform (Optional[Callable]): The transform of the target. Defaults to None.
max_length (int): The maximum number of characters. Defaults to 34.
@@ -61,15 +64,19 @@ class EmnistLinesDataset(Dataset):
if self.target_transform is None:
self.target_transform = torch.tensor
- self.mapping = self.emnist.mapping
- self.num_classes = self.emnist.num_classes
+ # Load emnist dataset infromation.
+ essentials = _load_emnist_essentials()
+ self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
+ self.num_classes = len(self.mapping)
+ self.input_shape = essentials["input_shape"]
+
self.max_length = max_length
self.min_overlap = min_overlap
self.max_overlap = max_overlap
self.num_samples = num_samples
self.input_shape = (
- self.emnist.input_shape[0],
- self.emnist.input_shape[1] * self.max_length,
+ self.input_shape[0],
+ self.input_shape[1] * self.max_length,
)
self.output_shape = (self.max_length, self.num_classes)
self.seed = seed
@@ -325,3 +332,94 @@ def create_datasets(
num_samples=num,
)
emnist_lines._load_or_generate_data()
+
+
+class EmnistLinesDataLoaders:
+ """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset."""
+
+ def __init__(
+ self,
+ splits: List[str],
+ max_length: int = 34,
+ min_overlap: float = 0,
+ max_overlap: float = 0.33,
+ num_samples: int = 10000,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ batch_size: int = 128,
+ shuffle: bool = False,
+ num_workers: int = 0,
+ cuda: bool = True,
+ seed: int = 4711,
+ ) -> None:
+ """Sets the data loader arguments."""
+ self.splits = splits
+ self.dataset_args = {
+ "max_length": max_length,
+ "min_overlap": min_overlap,
+ "max_overlap": max_overlap,
+ "num_samples": num_samples,
+ "transform": transform,
+ "target_transform": target_transform,
+ "seed": seed,
+ }
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.num_workers = num_workers
+ self.cuda = cuda
+ self._data_loaders = self._load_data_loaders()
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return self._data_loaders[self.splits[0]].dataset.__repr__()
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EmnistLines"
+
+ def __call__(self, split: str) -> DataLoader:
+ """Returns the `split` DataLoader.
+
+ Args:
+ split (str): The dataset split, i.e. train or val.
+
+ Returns:
+ DataLoader: A PyTorch DataLoader.
+
+ Raises:
+ ValueError: If the split does not exist.
+
+ """
+ try:
+ return self._data_loaders[split]
+ except KeyError:
+ raise ValueError(f"Split {split} does not exist.")
+
+ def _load_data_loaders(self) -> Dict[str, DataLoader]:
+ """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders."""
+ data_loaders = {}
+
+ for split in ["train", "val"]:
+ if split in self.splits:
+
+ if split == "train":
+ self.dataset_args["train"] = True
+ else:
+ self.dataset_args["train"] = False
+
+ emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args)
+
+ emnist_lines_dataset._load_or_generate_data()
+
+ data_loader = DataLoader(
+ dataset=emnist_lines_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ pin_memory=self.cuda,
+ )
+
+ data_loaders[split] = data_loader
+
+ return data_loaders