From 65df6a72b002c4b23d6f2eb545839e157f7f2aa0 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 5 Jun 2022 23:39:11 +0200 Subject: Remove attrs --- text_recognizer/data/emnist_lines.py | 49 ++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 13 deletions(-) (limited to 'text_recognizer/data/emnist_lines.py') diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 43d55b9..062257d 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,9 +1,8 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import DefaultDict, List, Tuple +from typing import Callable, DefaultDict, List, Optional, Tuple, Type -from attrs import define, field import h5py from loguru import logger as log import numpy as np @@ -17,6 +16,7 @@ from text_recognizer.data.base_data_module import ( ) from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator @@ -33,22 +33,45 @@ IMAGE_X_PADDING = 28 MAX_OUTPUT_LENGTH = 89 # Same as IAMLines -@define(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" - max_length: int = field(default=128) - min_overlap: float = field(default=0.0) - max_overlap: float = field(default=0.33) - num_train: int = field(default=10_000) - num_val: int = field(default=2_000) - num_test: int = field(default=2_000) - emnist: EMNIST = field(init=False, default=None) + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + max_length: int = 128, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__( + mapping, + transform, + test_transform, + target_transform, + train_fraction, + batch_size, + num_workers, + pin_memory, + ) - def __attrs_post_init__(self) -> None: - """Post init constructor.""" - self.emnist = EMNIST(mapping=self.mapping) + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + self.emnist = EMNIST(mapping=self.mapping) max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING -- cgit v1.2.3-70-g09d2