diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
commit | 30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch) | |
tree | 08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/base_data_module.py | |
parent | ad3f404d36a9add32992698dd083d368f3b96812 (diff) |
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index ee70176..3add837 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,13 +1,13 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Dict, Optional, Tuple, Type, TypeVar +from typing import Callable, Dict, Optional, Tuple, Type, TypeVar import attr from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from text_recognizer.data.base_dataset import BaseDataset -from text_recognizer.data.base_mapping import AbstractMapping +from text_recognizer.data.mappings.base_mapping import AbstractMapping T = TypeVar("T") @@ -29,6 +29,10 @@ class BaseDataModule(LightningDataModule): super().__init__() mapping: Type[AbstractMapping] = attr.ib() + transform: Optional[Callable] = attr.ib(default=None) + test_transform: Optional[Callable] = attr.ib(default=None) + target_transform: Optional[Callable] = attr.ib(default=None) + train_fraction: float = attr.ib(default=0.8) batch_size: int = attr.ib(default=16) num_workers: int = attr.ib(default=0) pin_memory: bool = attr.ib(default=True) |