path: root/text_recognizer/datasets/base_data_module.py
Diffstat (limited to 'text_recognizer/datasets/base_data_module.py')
-rw-r--r--  text_recognizer/datasets/base_data_module.py  89
1 file changed, 0 insertions, 89 deletions
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
deleted file mode 100644
index f5e7300..0000000
--- a/text_recognizer/datasets/base_data_module.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""Base lightning DataModule class."""
-from pathlib import Path
-from typing import Dict
-
-import pytorch_lightning as pl
-from torch.utils.data import DataLoader
-
-
-def load_and_print_info(data_module_class: type) -> None:
- """Load EMNISTLines and prints info."""
- dataset = data_module_class()
- dataset.prepare_data()
- dataset.setup()
- print(dataset)
-
-
-class BaseDataModule(pl.LightningDataModule):
- """Base PyTorch Lightning DataModule."""
-
- def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
- super().__init__()
- self.batch_size = batch_size
- self.num_workers = num_workers
-
- # Placeholders for subclasses.
- self.dims = None
- self.output_dims = None
- self.mapping = None
-
- @classmethod
- def data_dirname(cls) -> Path:
- """Return the path to the base data directory."""
- return Path(__file__).resolve().parents[2] / "data"
-
- def config(self) -> Dict:
- """Return important settings of the dataset."""
- return {
- "input_dim": self.dims,
- "output_dims": self.output_dims,
- "mapping": self.mapping,
- }
-
- def prepare_data(self) -> None:
- """Prepare data for training."""
- pass
-
- def setup(self, stage: str = None) -> None:
- """Split into train, val, test, and set dims.
-
- Should assign `torch Dataset` objects to self.data_train, self.data_val, and
- optionally self.data_test.
-
- Args:
- stage (Any): Variable to set splits.
-
- """
- self.data_train = None
- self.data_val = None
- self.data_test = None
-
- def train_dataloader(self) -> DataLoader:
- """Retun DataLoader for train data."""
- return DataLoader(
- self.data_train,
- shuffle=True,
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- pin_memory=True,
- )
-
- def val_dataloader(self) -> DataLoader:
- """Return DataLoader for val data."""
- return DataLoader(
- self.data_val,
- shuffle=False,
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- pin_memory=True,
- )
-
- def test_dataloader(self) -> DataLoader:
- """Return DataLoader for val data."""
- return DataLoader(
- self.data_test,
- shuffle=False,
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- pin_memory=True,
- )
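Note: for context, the sketch below shows how a subclass of the removed BaseDataModule was intended to fill in the placeholders (dims, output_dims, mapping) and the prepare_data/setup hooks. It is a hypothetical example: the MNIST-based module, its split sizes, and its transforms are assumptions, not part of this repository, and it assumes the deleted module is still importable at its old path.

"""Hypothetical example: subclassing the removed BaseDataModule."""
from typing import Optional

from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

# Assumes the deleted module is still importable from its old location.
from text_recognizer.datasets.base_data_module import BaseDataModule, load_and_print_info


class MNISTDataModule(BaseDataModule):
    """Toy subclass that fills in the placeholders defined by BaseDataModule."""

    def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
        super().__init__(batch_size, num_workers)
        self.dims = (1, 28, 28)          # input dimensions reported by config()
        self.output_dims = (1,)          # one class label per sample
        self.mapping = list(range(10))   # digit classes 0-9
        self.transform = transforms.ToTensor()

    def prepare_data(self) -> None:
        """Download the data once; no state is assigned here."""
        MNIST(str(self.data_dirname()), train=True, download=True)
        MNIST(str(self.data_dirname()), train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Assign the train/val/test datasets consumed by the dataloaders."""
        full = MNIST(str(self.data_dirname()), train=True, transform=self.transform)
        # 55k/5k split of the 60k training images (arbitrary choice for this sketch).
        self.data_train, self.data_val = random_split(full, [55000, 5000])
        self.data_test = MNIST(str(self.data_dirname()), train=False, transform=self.transform)


if __name__ == "__main__":
    # Instantiate, prepare, set up, and print the module, as the base helper does.
    load_and_print_info(MNISTDataModule)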