Diffstat (limited to 'text_recognizer/datasets/base_data_module.py')
-rw-r--r--  text_recognizer/datasets/base_data_module.py  36
1 file changed, 28 insertions(+), 8 deletions(-)
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
index 09a0a43..830b39b 100644
--- a/text_recognizer/datasets/base_data_module.py
+++ b/text_recognizer/datasets/base_data_module.py
@@ -16,7 +16,7 @@ def load_and_print_info(data_module_class: type) -> None:
class BaseDataModule(pl.LightningDataModule):
"""Base PyTorch Lightning DataModule."""
-
+
def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
super().__init__()
self.batch_size = batch_size
@@ -34,13 +34,17 @@ class BaseDataModule(pl.LightningDataModule):
def config(self) -> Dict:
"""Return important settings of the dataset."""
- return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping}
+ return {
+ "input_dim": self.dims,
+ "output_dims": self.output_dims,
+ "mapping": self.mapping,
+ }
def prepare_data(self) -> None:
"""Prepare data for training."""
pass
- def setup(self, stage: Any = None) -> None:
+ def setup(self, stage: str = None) -> None:
"""Split into train, val, test, and set dims.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and
@@ -54,16 +58,32 @@ class BaseDataModule(pl.LightningDataModule):
self.data_val = None
self.data_test = None
-
def train_dataloader(self) -> DataLoader:
"""Retun DataLoader for train data."""
- return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+ return DataLoader(
+ self.data_train,
+ shuffle=True,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
def val_dataloader(self) -> DataLoader:
"""Return DataLoader for val data."""
- return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+ return DataLoader(
+ self.data_val,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
def test_dataloader(self) -> DataLoader:
"""Return DataLoader for val data."""
- return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
-
+ return DataLoader(
+ self.data_test,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
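For context, below is a minimal sketch of how a concrete dataset might subclass this BaseDataModule after the change. The RandomLinesDataModule class, its random-tensor data, and the dims/output_dims/mapping values are hypothetical and not part of this commit; the sketch only assumes that BaseDataModule and load_and_print_info are importable from text_recognizer.datasets.base_data_module, as the hunk header suggests.

# Hypothetical usage sketch (not part of this diff): a toy subclass that fills
# in the hooks BaseDataModule leaves empty, using random tensors as stand-ins.
import torch
from torch.utils.data import TensorDataset

from text_recognizer.datasets.base_data_module import BaseDataModule, load_and_print_info


class RandomLinesDataModule(BaseDataModule):
    """Toy data module for illustration only."""

    def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
        super().__init__(batch_size, num_workers)
        self.dims = (1, 28, 28)         # input shape reported by config()
        self.output_dims = (10,)        # output shape reported by config()
        self.mapping = list(range(10))  # label mapping reported by config()

    def setup(self, stage: str = None) -> None:
        # Assign torch Dataset objects, as the base class docstring asks.
        x = torch.randn(256, *self.dims)
        y = torch.randint(0, 10, (256,))
        self.data_train = TensorDataset(x, y)
        self.data_val = TensorDataset(x, y)
        self.data_test = TensorDataset(x, y)


if __name__ == "__main__":
    load_and_print_info(RandomLinesDataModule)

With this in place, the inherited train_dataloader/val_dataloader/test_dataloader methods return DataLoaders over the assigned datasets using the batch_size and num_workers passed to the constructor.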