diff options
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 15c286a..a0c8416 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -import attr +from attrs import define, field from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -20,29 +20,29 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -@attr.s(repr=False) +@define(repr=False) class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __attrs_pre_init__(self) -> None: + def __attrs_post_init__(self) -> None: """Pre init constructor.""" super().__init__() - mapping: Type[AbstractMapping] = attr.ib() - transform: Optional[Callable] = attr.ib(default=None) - test_transform: Optional[Callable] = attr.ib(default=None) - target_transform: Optional[Callable] = attr.ib(default=None) - train_fraction: float = attr.ib(default=0.8) - batch_size: int = attr.ib(default=16) - num_workers: int = attr.ib(default=0) - pin_memory: bool = attr.ib(default=True) + mapping: Type[AbstractMapping] = field() + transform: Optional[Callable] = field(default=None) + test_transform: Optional[Callable] = field(default=None) + target_transform: Optional[Callable] = field(default=None) + train_fraction: float = field(default=0.8) + batch_size: int = field(default=16) + num_workers: int = field(default=0) + pin_memory: bool = field(default=True) # Placeholders - data_train: BaseDataset = attr.ib(init=False, default=None) - data_val: BaseDataset = attr.ib(init=False, default=None) - data_test: BaseDataset = attr.ib(init=False, default=None) - dims: Tuple[int, ...] = attr.ib(init=False, default=None) - output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) + data_train: BaseDataset = field(init=False, default=None) + data_val: BaseDataset = field(init=False, default=None) + data_test: BaseDataset = field(init=False, default=None) + dims: Tuple[int, ...] = field(init=False, default=None) + output_dims: Tuple[int, ...] = field(init=False, default=None) @classmethod def data_dirname(cls: T) -> Path: |