summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-05 23:39:11 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-05 23:39:11 +0200
commit65df6a72b002c4b23d6f2eb545839e157f7f2aa0 (patch)
treed78df1d7143dc9ff9e29afd4fd6bc7490bc79418 /text_recognizer/data/base_data_module.py
parent8bc4b4cab00a2777a748c10fca9b3ee01e32277c (diff)
Remove attrs
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py46
1 files changed, 26 insertions, 20 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index a0c8416..6306cf8 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,6 @@
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
-from attrs import define, field
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -20,29 +19,36 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@define(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __attrs_post_init__(self) -> None:
- """Pre init constructor."""
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ ) -> None:
super().__init__()
-
- mapping: Type[AbstractMapping] = field()
- transform: Optional[Callable] = field(default=None)
- test_transform: Optional[Callable] = field(default=None)
- target_transform: Optional[Callable] = field(default=None)
- train_fraction: float = field(default=0.8)
- batch_size: int = field(default=16)
- num_workers: int = field(default=0)
- pin_memory: bool = field(default=True)
-
- # Placeholders
- data_train: BaseDataset = field(init=False, default=None)
- data_val: BaseDataset = field(init=False, default=None)
- data_test: BaseDataset = field(init=False, default=None)
- dims: Tuple[int, ...] = field(init=False, default=None)
- output_dims: Tuple[int, ...] = field(init=False, default=None)
+ self.mapping = mapping
+ self.transform = transform
+ self.test_transform = test_transform
+ self.target_transform = target_transform
+ self.train_fraction = train_fraction
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.pin_memory = pin_memory
+
+ # Placeholders
+ self.data_train: BaseDataset
+ self.data_val: BaseDataset
+ self.data_test: BaseDataset
+ self.dims: Tuple[int, ...]
+ self.output_dims: Tuple[int, ...]
@classmethod
def data_dirname(cls: T) -> Path: