Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 46 |
1 file changed, 26 insertions, 20 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index a0c8416..6306cf8 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,6 @@
 from pathlib import Path
 from typing import Callable, Dict, Optional, Tuple, Type, TypeVar

-from attrs import define, field
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader

@@ -20,29 +19,36 @@ def load_and_print_info(data_module_class: type) -> None:
     print(dataset)


-@define(repr=False)
 class BaseDataModule(LightningDataModule):
     """Base PyTorch Lightning DataModule."""

-    def __attrs_post_init__(self) -> None:
-        """Pre init constructor."""
+    def __init__(
+        self,
+        mapping: Type[AbstractMapping],
+        transform: Optional[Callable] = None,
+        test_transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        train_fraction: float = 0.8,
+        batch_size: int = 16,
+        num_workers: int = 0,
+        pin_memory: bool = True,
+    ) -> None:
         super().__init__()
-
-    mapping: Type[AbstractMapping] = field()
-    transform: Optional[Callable] = field(default=None)
-    test_transform: Optional[Callable] = field(default=None)
-    target_transform: Optional[Callable] = field(default=None)
-    train_fraction: float = field(default=0.8)
-    batch_size: int = field(default=16)
-    num_workers: int = field(default=0)
-    pin_memory: bool = field(default=True)
-
-    # Placeholders
-    data_train: BaseDataset = field(init=False, default=None)
-    data_val: BaseDataset = field(init=False, default=None)
-    data_test: BaseDataset = field(init=False, default=None)
-    dims: Tuple[int, ...] = field(init=False, default=None)
-    output_dims: Tuple[int, ...] = field(init=False, default=None)
+        self.mapping = mapping
+        self.transform = transform
+        self.test_transform = test_transform
+        self.target_transform = target_transform
+        self.train_fraction = train_fraction
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+
+        # Placeholders
+        self.data_train: BaseDataset
+        self.data_val: BaseDataset
+        self.data_test: BaseDataset
+        self.dims: Tuple[int, ...]
+        self.output_dims: Tuple[int, ...]

     @classmethod
     def data_dirname(cls: T) -> Path:
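The change replaces the attrs-generated constructor with an explicit __init__, so concrete data modules now pass their settings as ordinary keyword arguments rather than declaring class-level field() attributes. A minimal sketch of how a subclass might forward its settings after this change follows; the MyDataModule name and the mappings import path are assumptions for illustration, not taken from this diff.

# Hypothetical subclass illustrating the new keyword-based constructor.
# The import path for AbstractMapping is an assumption about this repository's layout.
from typing import Optional, Type

from text_recognizer.data.base_data_module import BaseDataModule
from text_recognizer.data.mappings import AbstractMapping  # assumed location


class MyDataModule(BaseDataModule):
    """Example data module that forwards its settings to BaseDataModule.__init__."""

    def __init__(self, mapping: Type[AbstractMapping], batch_size: int = 32, **kwargs) -> None:
        # Former attrs fields (transform, train_fraction, num_workers, ...) are now
        # plain keyword arguments handled by the parent constructor.
        super().__init__(mapping=mapping, batch_size=batch_size, **kwargs)

    def setup(self, stage: Optional[str] = None) -> None:
        # A real implementation would assign the placeholder attributes here,
        # e.g. self.data_train, self.data_val and self.data_test.
        ...

Note that the placeholders (self.data_train, self.dims, and so on) are now bare annotations inside __init__: they declare attribute types for tooling but do not assign values, so accessing them before setup() populates them raises AttributeError instead of returning None as the old field(init=False, default=None) declarations did.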