diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 20:03:10 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 20:03:10 +0100 |
commit | aac452a2dc008338cb543549652da293c14b6b4e (patch) | |
tree | 6d018841e28f22eee5289f6cc59c28100a9d023d /text_recognizer/datasets/base_data_module.py | |
parent | a3a40c9c0118039460d5c9fba6a74edc0cdba106 (diff) |
Refactor EMNIST dataset
Diffstat (limited to 'text_recognizer/datasets/base_data_module.py')
-rw-r--r-- | text_recognizer/datasets/base_data_module.py | 69 |
1 files changed, 69 insertions, 0 deletions
"""Base lightning DataModule class."""
from pathlib import Path
from typing import Any, Dict

import pytorch_lightning as pl
from torch.utils.data import DataLoader


def load_and_print_info(data_module_class: type) -> None:
    """Instantiate a data module, prepare and set it up, then print it.

    Args:
        data_module_class (type): Class to instantiate. It is called with no
            arguments and must provide `prepare_data()` and `setup()`.

    """
    dataset = data_module_class()
    dataset.prepare_data()
    dataset.setup()
    print(dataset)


class BaseDataModule(pl.LightningDataModule):
    """Base PyTorch Lightning DataModule.

    Stores the loader settings shared by all splits and builds the
    train/val/test `DataLoader` objects from `self.data_train`,
    `self.data_val` and `self.data_test`. Those attributes are only created
    by `setup()`, so `setup()` has to run before any `*_dataloader()` call.
    """

    def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
        """Store the loader settings.

        Args:
            batch_size (int): Batch size passed to every DataLoader.
            num_workers (int): Worker count passed to every DataLoader.

        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Placeholders for subclasses.
        self.dims = None
        self.output_dims = None
        self.mapping = None

    @classmethod
    def data_dirname(cls) -> Path:
        """Return the path to the base data directory.

        The directory is `data` under the third ancestor (`parents[2]`) of
        this file's resolved path.
        """
        return Path(__file__).resolve().parents[2] / "data"

    def config(self) -> Dict:
        """Return important settings of the dataset."""
        # NOTE(review): "input_dim" is singular while "output_dims" is plural;
        # the key is kept as is because callers may depend on it.
        return {
            "input_dim": self.dims,
            "output_dims": self.output_dims,
            "mapping": self.mapping,
        }

    def prepare_data(self) -> None:
        """Prepare data for training.

        Does nothing in the base class.
        """

    def setup(self, stage: Any = None) -> None:
        """Split into train, val, test, and set dims.

        Should assign `torch Dataset` objects to self.data_train, self.data_val, and
        optionally self.data_test. The base implementation sets all three to
        None.

        Args:
            stage (Any): Variable to set splits. Unused by the base
                implementation.

        """
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def train_dataloader(self) -> DataLoader:
        """Return DataLoader for train data (shuffled)."""
        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return DataLoader for val data (not shuffled)."""
        return DataLoader(
            self.data_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        """Return DataLoader for test data (not shuffled)."""
        return DataLoader(
            self.data_test,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )