summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index ee70176..3add837 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,13 +1,13 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Optional, Tuple, Type, TypeVar
+from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from text_recognizer.data.base_dataset import BaseDataset
-from text_recognizer.data.base_mapping import AbstractMapping
+from text_recognizer.data.mappings.base_mapping import AbstractMapping
T = TypeVar("T")
@@ -29,6 +29,10 @@ class BaseDataModule(LightningDataModule):
super().__init__()
mapping: Type[AbstractMapping] = attr.ib()
+ transform: Optional[Callable] = attr.ib(default=None)
+ test_transform: Optional[Callable] = attr.ib(default=None)
+ target_transform: Optional[Callable] = attr.ib(default=None)
+ train_fraction: float = attr.ib(default=0.8)
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
pin_memory: bool = attr.ib(default=True)