diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
commit | 65df6a72b002c4b23d6f2eb545839e157f7f2aa0 (patch) | |
tree | d78df1d7143dc9ff9e29afd4fd6bc7490bc79418 /text_recognizer/data/emnist_lines.py | |
parent | 8bc4b4cab00a2777a748c10fca9b3ee01e32277c (diff) |
Remove attrs
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 49 |
1 files changed, 36 insertions, 13 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 43d55b9..062257d 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,9 +1,8 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import DefaultDict, List, Tuple +from typing import Callable, DefaultDict, List, Optional, Tuple, Type -from attrs import define, field import h5py from loguru import logger as log import numpy as np @@ -17,6 +16,7 @@ from text_recognizer.data.base_data_module import ( ) from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator @@ -33,22 +33,45 @@ IMAGE_X_PADDING = 28 MAX_OUTPUT_LENGTH = 89 # Same as IAMLines -@define(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" - max_length: int = field(default=128) - min_overlap: float = field(default=0.0) - max_overlap: float = field(default=0.33) - num_train: int = field(default=10_000) - num_val: int = field(default=2_000) - num_test: int = field(default=2_000) - emnist: EMNIST = field(init=False, default=None) + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + max_length: int = 128, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__( + mapping, + transform, + test_transform, + target_transform, + train_fraction, + batch_size, + num_workers, + pin_memory, + ) - def __attrs_post_init__(self) -> None: - """Post init constructor.""" - self.emnist = EMNIST(mapping=self.mapping) + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + self.emnist = EMNIST(mapping=self.mapping) max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING |