summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py51
1 files changed, 37 insertions, 14 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6268a01..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
+import click
import h5py
from loguru import logger
import numpy as np
@@ -37,6 +38,9 @@ class EmnistLinesDataset(Dataset):
max_overlap: float = 0.33,
num_samples: int = 10000,
seed: int = 4711,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Set attributes and loads the dataset.
@@ -50,13 +54,21 @@ class EmnistLinesDataset(Dataset):
max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
num_samples (int): Number of samples to generate. Defaults to 10000.
seed (int): Seed number. Defaults to 4711.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
"""
+ self.pad_token = "_" if pad_token is None else pad_token
+
super().__init__(
train=train,
transform=transform,
target_transform=target_transform,
subsample_fraction=subsample_fraction,
+ init_token=init_token,
+ pad_token=self.pad_token,
+ eos_token=eos_token,
)
# Extract dataset information.
@@ -118,11 +130,7 @@ class EmnistLinesDataset(Dataset):
@property
def data_filename(self) -> Path:
"""Path to the h5 file."""
- filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
- if self.train:
- filename = "train_" + filename
- else:
- filename = "test_" + filename
+ filename = "train.pt" if self.train else "test.pt"
return DATA_DIRNAME / filename
def load_or_generate_data(self) -> None:
@@ -138,8 +146,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self._data = f["data"][:]
- self._targets = f["targets"][:]
+ self._data = f["data"][()]
+ self._targets = f["targets"][()]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""
@@ -148,7 +156,10 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+ emnist = EmnistDataset(
+ train=self.train, sample_to_balance=True, pad_token=self.pad_token
+ )
+ emnist.load_or_generate_data()
samples_by_character = get_samples_by_character(
emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
@@ -298,6 +309,18 @@ def convert_strings_to_categorical_labels(
return np.array([[mapping[c] for c in label] for label in labels])
+@click.command()
+@click.option(
+ "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+ "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+ "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
def create_datasets(
max_length: int = 34,
min_overlap: float = 0,
@@ -306,17 +329,17 @@ def create_datasets(
num_test: int = 1000,
) -> None:
"""Creates a training an validation dataset of Emnist lines."""
- emnist_train = EmnistDataset(train=True, sample_to_balance=True)
- emnist_test = EmnistDataset(train=False, sample_to_balance=True)
- datasets = [emnist_train, emnist_test]
num_samples = [num_train, num_test]
- for num, train, dataset in zip(num_samples, [True, False], datasets):
+ for num, train in zip(num_samples, [True, False]):
emnist_lines = EmnistLinesDataset(
train=train,
- emnist=dataset,
max_length=max_length,
min_overlap=min_overlap,
max_overlap=max_overlap,
num_samples=num,
)
- emnist_lines._load_or_generate_data()
+ emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+ create_datasets()