From eb5b206f7e1b08435378d2a02395307be55ee6f1 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 6 Jul 2021 17:42:53 +0200 Subject: Refactoring data with attrs and refactor conf for hydra --- text_recognizer/data/iam_lines.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) (limited to 'text_recognizer/data/iam_lines.py') diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 9c78a22..e45e5c8 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,8 +7,9 @@ dataset. import json from pathlib import Path import random -from typing import List, Sequence, Tuple +from typing import Dict, List, Sequence, Tuple +import attr from loguru import logger from PIL import Image, ImageFile, ImageOps import numpy as np @@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56 IMAGE_WIDTH = 1024 +@attr.s(auto_attribs=True) class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - def __init__( - self, - augment: bool = True, - fraction: float = 0.8, - batch_size: int = 128, - num_workers: int = 0, - ) -> None: - # TODO: add transforms - super().__init__(batch_size, num_workers) - self.augment = augment - self.fraction = fraction + augment: bool = attr.ib(default=True) + fraction: float = attr.ib(default=0.8) + + def __attrs_post_init__(self) -> None: self.mapping, self.inverse_mapping, _ = emnist_mapping() self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) self.output_dims = (89, 1) - self.data_train: BaseDataset = None - self.data_val: BaseDataset = None - self.data_test: BaseDataset = None def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" -- cgit v1.2.3-70-g09d2