diff options
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 16a06d9..ee70176 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,13 +1,15 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Dict, Tuple, Type +from typing import Dict, Optional, Tuple, Type, TypeVar import attr from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from text_recognizer.data.base_mapping import AbstractMapping from text_recognizer.data.base_dataset import BaseDataset +from text_recognizer.data.base_mapping import AbstractMapping + +T = TypeVar("T") def load_and_print_info(data_module_class: type) -> None: @@ -23,6 +25,7 @@ class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" def __attrs_pre_init__(self) -> None: + """Pre init constructor.""" super().__init__() mapping: Type[AbstractMapping] = attr.ib() @@ -38,7 +41,7 @@ class BaseDataModule(LightningDataModule): output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) @classmethod - def data_dirname(cls) -> Path: + def data_dirname(cls: T) -> Path: """Return the path to the base data directory.""" return Path(__file__).resolve().parents[2] / "data" @@ -53,14 +56,14 @@ class BaseDataModule(LightningDataModule): """Prepare data for training.""" pass - def setup(self, stage: str = None) -> None: + def setup(self, stage: Optional[str] = None) -> None: """Split into train, val, test, and set dims. Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. Args: - stage (Any): Variable to set splits. + stage (Optional[str]): Variable to set splits. """ pass |