diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
commit | 3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch) | |
tree | e1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer/datasets/emnist_lines_dataset.py | |
parent | 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff) |
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 9 |
1 file changed, 7 insertions, 2 deletions
```diff
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 8fa77cd..6268a01 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -19,7 +19,6 @@ from text_recognizer.datasets.util import (
     EmnistMapper,
     ESSENTIALS_FILENAME,
 )
-from text_recognizer.networks import sliding_window
 
 DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
 
@@ -32,6 +31,7 @@ class EmnistLinesDataset(Dataset):
         train: bool = False,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        subsample_fraction: float = None,
         max_length: int = 34,
         min_overlap: float = 0,
         max_overlap: float = 0.33,
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
             train (bool): Flag for the filename. Defaults to False. Defaults to None.
             transform (Optional[Callable]): The transform of the data. Defaults to None.
             target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+            subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
             max_length (int): The maximum number of characters. Defaults to 34.
             min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
             max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
@@ -52,7 +53,10 @@ class EmnistLinesDataset(Dataset):
 
         """
         super().__init__(
-            train=train, transform=transform, target_transform=target_transform,
+            train=train,
+            transform=transform,
+            target_transform=target_transform,
+            subsample_fraction=subsample_fraction,
         )
 
         # Extract dataset information.
@@ -128,6 +132,7 @@ class EmnistLinesDataset(Dataset):
         if not self.data_filename.exists():
             self._generate_data()
         self._load_data()
+        self._subsample()
 
     def _load_data(self) -> None:
         """Loads the dataset from the h5 file."""
```