diff options
Diffstat (limited to 'text_recognizer/datasets/base_data_module.py')
-rw-r--r-- | text_recognizer/datasets/base_data_module.py | 36 |
1 files changed, 28 insertions, 8 deletions
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py index 09a0a43..830b39b 100644 --- a/text_recognizer/datasets/base_data_module.py +++ b/text_recognizer/datasets/base_data_module.py @@ -16,7 +16,7 @@ def load_and_print_info(data_module_class: type) -> None: class BaseDataModule(pl.LightningDataModule): """Base PyTorch Lightning DataModule.""" - + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: super().__init__() self.batch_size = batch_size @@ -34,13 +34,17 @@ class BaseDataModule(pl.LightningDataModule): def config(self) -> Dict: """Return important settings of the dataset.""" - return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping} + return { + "input_dim": self.dims, + "output_dims": self.output_dims, + "mapping": self.mapping, + } def prepare_data(self) -> None: """Prepare data for training.""" pass - def setup(self, stage: Any = None) -> None: + def setup(self, stage: str = None) -> None: """Split into train, val, test, and set dims. 
Should assign `torch Dataset` objects to self.data_train, self.data_val, and @@ -54,16 +58,32 @@ class BaseDataModule(pl.LightningDataModule): self.data_val = None self.data_test = None - def train_dataloader(self) -> DataLoader: """Return DataLoader for train data.""" - return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_train, + shuffle=True, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def val_dataloader(self) -> DataLoader: """Return DataLoader for val data.""" - return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_val, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def test_dataloader(self) -> DataLoader: """Return DataLoader for test data.""" - return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) - + return DataLoader( + self.data_test, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) |