diff options
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 29 |
1 files changed, 16 insertions, 13 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index de5628f..18b1996 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,11 +1,13 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Dict +from typing import Any, Dict, Tuple import attr -import pytorch_lightning as LightningDataModule +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader +from text_recognizer.data.base_dataset import BaseDataset + def load_and_print_info(data_module_class: type) -> None: """Load dataset and print dataset information.""" @@ -19,17 +21,20 @@ def load_and_print_info(data_module_class: type) -> None: class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - batch_size: int = attr.ib(default=16) - num_workers: int = attr.ib(default=0) - def __attrs_pre_init__(self) -> None: super().__init__() - def __attrs_post_init__(self) -> None: - # Placeholders for subclasses. - self.dims = None - self.output_dims = None - self.mapping = None + batch_size: int = attr.ib(default=16) + num_workers: int = attr.ib(default=0) + + # Placeholders + data_train: BaseDataset = attr.ib(init=False, default=None) + data_val: BaseDataset = attr.ib(init=False, default=None) + data_test: BaseDataset = attr.ib(init=False, default=None) + dims: Tuple[int, ...] = attr.ib(init=False, default=None) + output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) + mapping: Any = attr.ib(init=False, default=None) + inverse_mapping: Dict[str, int] = attr.ib(init=False) @classmethod def data_dirname(cls) -> Path: @@ -58,9 +63,7 @@ class BaseDataModule(LightningDataModule): stage (Any): Variable to set splits. """ - self.data_train = None - self.data_val = None - self.data_test = None + pass def train_dataloader(self) -> DataLoader: """Retun DataLoader for train data.""" |