summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
commit75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree6521cc4134459e42591b2375f70acd348741474e /text_recognizer/data/base_data_module.py
parente5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 408ae36..fd914b6 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ mapping: AbstractMapping = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
+ pin_memory: bool = attr.ib(default=True)
# Placeholders
data_train: BaseDataset = attr.ib(init=False, default=None)
@@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule):
data_test: BaseDataset = attr.ib(init=False, default=None)
dims: Tuple[int, ...] = attr.ib(init=False, default=None)
output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- mapping: Any = attr.ib(init=False, default=None)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule):
return {
"input_dim": self.dims,
"output_dims": self.output_dims,
- "mapping": self.mapping,
}
def prepare_data(self) -> None:
@@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def val_dataloader(self) -> DataLoader:
@@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def test_dataloader(self) -> DataLoader:
@@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)