summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
commiteb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/data/base_data_module.py
parent4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py29
1 files changed, 16 insertions, 13 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index de5628f..18b1996 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,13 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict
+from typing import Any, Dict, Tuple
import attr
-import pytorch_lightning as LightningDataModule
+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.base_dataset import BaseDataset
+
def load_and_print_info(data_module_class: type) -> None:
"""Load dataset and print dataset information."""
@@ -19,17 +21,20 @@ def load_and_print_info(data_module_class: type) -> None:
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- batch_size: int = attr.ib(default=16)
- num_workers: int = attr.ib(default=0)
-
def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self) -> None:
- # Placeholders for subclasses.
- self.dims = None
- self.output_dims = None
- self.mapping = None
+ batch_size: int = attr.ib(default=16)
+ num_workers: int = attr.ib(default=0)
+
+ # Placeholders
+ data_train: BaseDataset = attr.ib(init=False, default=None)
+ data_val: BaseDataset = attr.ib(init=False, default=None)
+ data_test: BaseDataset = attr.ib(init=False, default=None)
+ dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ mapping: Any = attr.ib(init=False, default=None)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -58,9 +63,7 @@ class BaseDataModule(LightningDataModule):
stage (Any): Variable to set splits.
"""
- self.data_train = None
- self.data_val = None
- self.data_test = None
+ pass
def train_dataloader(self) -> DataLoader:
"""Retun DataLoader for train data."""