summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 408ae36..fd914b6 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ mapping: AbstractMapping = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
+ pin_memory: bool = attr.ib(default=True)
# Placeholders
data_train: BaseDataset = attr.ib(init=False, default=None)
@@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule):
data_test: BaseDataset = attr.ib(init=False, default=None)
dims: Tuple[int, ...] = attr.ib(init=False, default=None)
output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- mapping: Any = attr.ib(init=False, default=None)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule):
return {
"input_dim": self.dims,
"output_dims": self.output_dims,
- "mapping": self.mapping,
}
def prepare_data(self) -> None:
@@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def val_dataloader(self) -> DataLoader:
@@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def test_dataloader(self) -> DataLoader:
@@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)