summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-24 22:15:54 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-24 22:15:54 +0100
commit8248f173132dfb7e47ec62b08e9235990c8626e3 (patch)
tree2f3ff85602cbc08b7168bf4f0d3924d32a689852 /text_recognizer/data/base_data_module.py
parent74c907a17379688967dc4b3f41a44ba83034f5e0 (diff)
renamed datasets to data, added iam refactor
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py89
1 files changed, 89 insertions, 0 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
new file mode 100644
index 0000000..f5e7300
--- /dev/null
+++ b/text_recognizer/data/base_data_module.py
@@ -0,0 +1,89 @@
+"""Base lightning DataModule class."""
+from pathlib import Path
+from typing import Dict
+
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+
+def load_and_print_info(data_module_class: type) -> None:
+ """Load EMNISTLines and prints info."""
+ dataset = data_module_class()
+ dataset.prepare_data()
+ dataset.setup()
+ print(dataset)
+
+
+class BaseDataModule(pl.LightningDataModule):
+ """Base PyTorch Lightning DataModule."""
+
+ def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
+ super().__init__()
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+
+ # Placeholders for subclasses.
+ self.dims = None
+ self.output_dims = None
+ self.mapping = None
+
+ @classmethod
+ def data_dirname(cls) -> Path:
+ """Return the path to the base data directory."""
+ return Path(__file__).resolve().parents[2] / "data"
+
+ def config(self) -> Dict:
+ """Return important settings of the dataset."""
+ return {
+ "input_dim": self.dims,
+ "output_dims": self.output_dims,
+ "mapping": self.mapping,
+ }
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ pass
+
+ def setup(self, stage: str = None) -> None:
+ """Split into train, val, test, and set dims.
+
+ Should assign `torch Dataset` objects to self.data_train, self.data_val, and
+ optionally self.data_test.
+
+ Args:
+ stage (Any): Variable to set splits.
+
+ """
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+
+ def train_dataloader(self) -> DataLoader:
+ """Retun DataLoader for train data."""
+ return DataLoader(
+ self.data_train,
+ shuffle=True,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ """Return DataLoader for val data."""
+ return DataLoader(
+ self.data_val,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ """Return DataLoader for val data."""
+ return DataLoader(
+ self.data_test,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )