From eb5b206f7e1b08435378d2a02395307be55ee6f1 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 6 Jul 2021 17:42:53 +0200 Subject: Refactoring data with attrs and refactor conf for hydra --- text_recognizer/data/iam_extended_paragraphs.py | 33 ++++++++++++++----------- 1 file changed, 18 insertions(+), 15 deletions(-) (limited to 'text_recognizer/data/iam_extended_paragraphs.py') diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 0a30a42..886e37e 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,4 +1,7 @@ """IAM original and sythetic dataset class.""" +from typing import Dict, List + +import attr from torch.utils.data import ConcatDataset from text_recognizer.data.base_dataset import BaseDataset @@ -7,22 +10,26 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs +@attr.s(auto_attribs=True) class IAMExtendedParagraphs(BaseDataModule): - def __init__( - self, - batch_size: int = 16, - num_workers: int = 0, - train_fraction: float = 0.8, - augment: bool = True, - word_pieces: bool = False, - ) -> None: - super().__init__(batch_size, num_workers) + train_fraction: float = attr.ib() + word_pieces: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( - batch_size, num_workers, train_fraction, augment, word_pieces, + self.batch_size, + self.num_workers, + self.train_fraction, + self.augment, + self.word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, num_workers, train_fraction, augment, word_pieces, + self.batch_size, + self.num_workers, + self.train_fraction, + self.augment, + self.word_pieces, ) self.dims = self.iam_paragraphs.dims @@ -30,10 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule): self.mapping = self.iam_paragraphs.mapping self.inverse_mapping = self.iam_paragraphs.inverse_mapping - self.data_train: BaseDataset = None - self.data_val: BaseDataset = None - self.data_test: BaseDataset = None - def prepare_data(self) -> None: """Prepares the paragraphs data.""" self.iam_paragraphs.prepare_data() -- cgit v1.2.3-70-g09d2