diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 | 
| commit | 4d1f2cef39688871d2caafce42a09316381a27ae (patch) | |
| tree | 0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/data/base_data_module.py | |
| parent | f0481decdad9afb52494e9e95996deef843ef233 (diff) | |
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
| -rw-r--r-- | text_recognizer/data/base_data_module.py | 14 | 
1 files changed, 9 insertions, 5 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 8b5c188..de5628f 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,8 @@  from pathlib import Path  from typing import Dict -import pytorch_lightning as pl +import attr +import pytorch_lightning as LightningDataModule  from torch.utils.data import DataLoader @@ -14,14 +15,17 @@ def load_and_print_info(data_module_class: type) -> None:      print(dataset) -class BaseDataModule(pl.LightningDataModule): +@attr.s +class BaseDataModule(LightningDataModule):      """Base PyTorch Lightning DataModule.""" -    def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: +    batch_size: int = attr.ib(default=16) +    num_workers: int = attr.ib(default=0) + +    def __attrs_pre_init__(self) -> None:          super().__init__() -        self.batch_size = batch_size -        self.num_workers = num_workers +    def __attrs_post_init__(self) -> None:          # Placeholders for subclasses.          self.dims = None          self.output_dims = None  |