summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
commit30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch)
tree08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/base_data_module.py
parentad3f404d36a9add32992698dd083d368f3b96812 (diff)
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index ee70176..3add837 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,13 +1,13 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Optional, Tuple, Type, TypeVar
+from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from text_recognizer.data.base_dataset import BaseDataset
-from text_recognizer.data.base_mapping import AbstractMapping
+from text_recognizer.data.mappings.base_mapping import AbstractMapping
T = TypeVar("T")
@@ -29,6 +29,10 @@ class BaseDataModule(LightningDataModule):
super().__init__()
mapping: Type[AbstractMapping] = attr.ib()
+ transform: Optional[Callable] = attr.ib(default=None)
+ test_transform: Optional[Callable] = attr.ib(default=None)
+ target_transform: Optional[Callable] = attr.ib(default=None)
+ train_fraction: float = attr.ib(default=0.8)
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
pin_memory: bool = attr.ib(default=True)