diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
commit | 65df6a72b002c4b23d6f2eb545839e157f7f2aa0 (patch) | |
tree | d78df1d7143dc9ff9e29afd4fd6bc7490bc79418 /text_recognizer/data/base_data_module.py | |
parent | 8bc4b4cab00a2777a748c10fca9b3ee01e32277c (diff) |
Remove attrs
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 46 |
1 files changed, 26 insertions, 20 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index a0c8416..6306cf8 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -from attrs import define, field from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -20,29 +19,36 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -@define(repr=False) class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __attrs_post_init__(self) -> None: - """Pre init constructor.""" + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + ) -> None: super().__init__() - - mapping: Type[AbstractMapping] = field() - transform: Optional[Callable] = field(default=None) - test_transform: Optional[Callable] = field(default=None) - target_transform: Optional[Callable] = field(default=None) - train_fraction: float = field(default=0.8) - batch_size: int = field(default=16) - num_workers: int = field(default=0) - pin_memory: bool = field(default=True) - - # Placeholders - data_train: BaseDataset = field(init=False, default=None) - data_val: BaseDataset = field(init=False, default=None) - data_test: BaseDataset = field(init=False, default=None) - dims: Tuple[int, ...] = field(init=False, default=None) - output_dims: Tuple[int, ...] = field(init=False, default=None) + self.mapping = mapping + self.transform = transform + self.test_transform = test_transform + self.target_transform = target_transform + self.train_fraction = train_fraction + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + + # Placeholders + self.data_train: BaseDataset + self.data_val: BaseDataset + self.data_test: BaseDataset + self.dims: Tuple[int, ...] + self.output_dims: Tuple[int, ...] @classmethod def data_dirname(cls: T) -> Path: |