From eb5b206f7e1b08435378d2a02395307be55ee6f1 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 6 Jul 2021 17:42:53 +0200
Subject: Refactoring data with attrs and refactor conf for hydra

---
 text_recognizer/data/emnist_lines.py | 35 +++++++++++------------------------
 1 file changed, 11 insertions(+), 24 deletions(-)

(limited to 'text_recognizer/data/emnist_lines.py')

diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 9650198..4747508 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -3,6 +3,7 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, Tuple
 
+import attr
 import h5py
 from loguru import logger
 import numpy as np
@@ -31,31 +32,20 @@ IMAGE_X_PADDING = 28
 MAX_OUTPUT_LENGTH = 89  # Same as IAMLines
 
 
+@attr.s(auto_attribs=True)
 class EMNISTLines(BaseDataModule):
     """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
 
-    def __init__(
-        self,
-        augment: bool = True,
-        batch_size: int = 128,
-        num_workers: int = 0,
-        max_length: int = 32,
-        min_overlap: float = 0.0,
-        max_overlap: float = 0.33,
-        num_train: int = 10_000,
-        num_val: int = 2_000,
-        num_test: int = 2_000,
-    ) -> None:
-        super().__init__(batch_size, num_workers)
-
-        self.augment = augment
-        self.max_length = max_length
-        self.min_overlap = min_overlap
-        self.max_overlap = max_overlap
-        self.num_train = num_train
-        self.num_val = num_val
-        self.num_test = num_test
+    augment: bool = attr.ib(default=True)
+    max_length: int = attr.ib(default=128)
+    min_overlap: float = attr.ib(default=0.0)
+    max_overlap: float = attr.ib(default=0.33)
+    num_train: int = attr.ib(default=10_000)
+    num_val: int = attr.ib(default=2_000)
+    num_test: int = attr.ib(default=2_000)
+    emnist: EMNIST = attr.ib(init=False, default=None)
 
+    def __attrs_post_init__(self) -> None:
         self.emnist = EMNIST()
         self.mapping = self.emnist.mapping
 
@@ -75,9 +65,6 @@ class EMNISTLines(BaseDataModule):
             raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
 
         self.output_dims = (MAX_OUTPUT_LENGTH, 1)
-        self.data_train: BaseDataset = None
-        self.data_val: BaseDataset = None
-        self.data_test: BaseDataset = None
 
     @property
     def data_filename(self) -> Path:
-- 
cgit v1.2.3-70-g09d2