summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_lines_dataset.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-14 22:15:47 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-14 22:15:47 +0200
commit3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch)
treee1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer/datasets/emnist_lines_dataset.py
parent2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff)
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 8fa77cd..6268a01 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -19,7 +19,6 @@ from text_recognizer.datasets.util import (
EmnistMapper,
ESSENTIALS_FILENAME,
)
-from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -32,6 +31,7 @@ class EmnistLinesDataset(Dataset):
train: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ subsample_fraction: float = None,
max_length: int = 34,
min_overlap: float = 0,
max_overlap: float = 0.33,
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
train (bool): Flag for the filename. Defaults to False. Defaults to None.
transform (Optional[Callable]): The transform of the data. Defaults to None.
target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
max_length (int): The maximum number of characters. Defaults to 34.
min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
@@ -52,7 +53,10 @@ class EmnistLinesDataset(Dataset):
"""
super().__init__(
- train=train, transform=transform, target_transform=target_transform,
+ train=train,
+ transform=transform,
+ target_transform=target_transform,
+ subsample_fraction=subsample_fraction,
)
# Extract dataset information.
@@ -128,6 +132,7 @@ class EmnistLinesDataset(Dataset):
if not self.data_filename.exists():
self._generate_data()
self._load_data()
+ self._subsample()
def _load_data(self) -> None:
"""Loads the dataset from the h5 file."""