summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 15c286a..a0c8416 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
-import attr
+from attrs import define, field
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -20,29 +20,29 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@attr.s(repr=False)
+@define(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __attrs_pre_init__(self) -> None:
+ def __attrs_post_init__(self) -> None:
"""Pre init constructor."""
super().__init__()
- mapping: Type[AbstractMapping] = attr.ib()
- transform: Optional[Callable] = attr.ib(default=None)
- test_transform: Optional[Callable] = attr.ib(default=None)
- target_transform: Optional[Callable] = attr.ib(default=None)
- train_fraction: float = attr.ib(default=0.8)
- batch_size: int = attr.ib(default=16)
- num_workers: int = attr.ib(default=0)
- pin_memory: bool = attr.ib(default=True)
+ mapping: Type[AbstractMapping] = field()
+ transform: Optional[Callable] = field(default=None)
+ test_transform: Optional[Callable] = field(default=None)
+ target_transform: Optional[Callable] = field(default=None)
+ train_fraction: float = field(default=0.8)
+ batch_size: int = field(default=16)
+ num_workers: int = field(default=0)
+ pin_memory: bool = field(default=True)
# Placeholders
- data_train: BaseDataset = attr.ib(init=False, default=None)
- data_val: BaseDataset = attr.ib(init=False, default=None)
- data_test: BaseDataset = attr.ib(init=False, default=None)
- dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ data_train: BaseDataset = field(init=False, default=None)
+ data_val: BaseDataset = field(init=False, default=None)
+ data_test: BaseDataset = field(init=False, default=None)
+ dims: Tuple[int, ...] = field(init=False, default=None)
+ output_dims: Tuple[int, ...] = field(init=False, default=None)
@classmethod
def data_dirname(cls: T) -> Path: