diff options
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/base_data_module.py | 14 | ||||
-rw-r--r-- | text_recognizer/data/base_dataset.py | 24 |
2 files changed, 21 insertions, 17 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 8b5c188..de5628f 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,8 @@ from pathlib import Path from typing import Dict -import pytorch_lightning as pl +import attr +import pytorch_lightning as LightningDataModule from torch.utils.data import DataLoader @@ -14,14 +15,17 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -class BaseDataModule(pl.LightningDataModule): +@attr.s +class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + batch_size: int = attr.ib(default=16) + num_workers: int = attr.ib(default=0) + + def __attrs_pre_init__(self) -> None: super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + def __attrs_post_init__(self) -> None: # Placeholders for subclasses. self.dims = None self.output_dims = None diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 8d644d4..4318dfb 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,11 +1,13 @@ """Base PyTorch Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union +import attr import torch from torch import Tensor from torch.utils.data import Dataset +@attr.s class BaseDataset(Dataset): """ Base Dataset class that processes data and targets through optional transfroms. @@ -18,19 +20,17 @@ class BaseDataset(Dataset): target transforms. """ - def __init__( - self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: - if len(data) != len(targets): + data: Union[Sequence, Tensor] = attr.ib() + targets: Union[Sequence, Tensor] = attr.ib() + transform: Callable = attr.ib() + target_transform: Callable = attr.ib() + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def __attrs_post_init__(self) -> None: + if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") - self.data = data - self.targets = targets - self.transform = transform - self.target_transform = target_transform def __len__(self) -> int: """Return the length of the dataset.""" |