summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 16a06d9..ee70176 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,13 +1,15 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Tuple, Type
+from typing import Dict, Optional, Tuple, Type, TypeVar
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
-from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_mapping import AbstractMapping
+
+T = TypeVar("T")
def load_and_print_info(data_module_class: type) -> None:
@@ -23,6 +25,7 @@ class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
def __attrs_pre_init__(self) -> None:
+ """Pre init constructor."""
super().__init__()
mapping: Type[AbstractMapping] = attr.ib()
@@ -38,7 +41,7 @@ class BaseDataModule(LightningDataModule):
output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
@classmethod
- def data_dirname(cls) -> Path:
+ def data_dirname(cls: T) -> Path:
"""Return the path to the base data directory."""
return Path(__file__).resolve().parents[2] / "data"
@@ -53,14 +56,14 @@ class BaseDataModule(LightningDataModule):
"""Prepare data for training."""
pass
- def setup(self, stage: str = None) -> None:
+ def setup(self, stage: Optional[str] = None) -> None:
"""Split into train, val, test, and set dims.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and
optionally self.data_test.
Args:
- stage (Any): Variable to set splits.
+ stage (Optional[str]): Variable to set splits.
"""
pass