summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_data_module.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_data_module.py')
-rw-r--r--text_recognizer/data/base_data_module.py11
1 files changed, 3 insertions, 8 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 7863333..bd6fd99 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from text_recognizer.data.base_dataset import BaseDataset
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
T = TypeVar("T")
@@ -24,7 +24,7 @@ class BaseDataModule(LightningDataModule):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -34,7 +34,7 @@ class BaseDataModule(LightningDataModule):
pin_memory: bool = True,
) -> None:
super().__init__()
- self.mapping = mapping
+ self.tokenizer = tokenizer
self.transform = transform
self.test_transform = test_transform
self.target_transform = target_transform
@@ -50,11 +50,6 @@ class BaseDataModule(LightningDataModule):
self.dims: Tuple[int, ...] = None
self.output_dims: Tuple[int, ...] = None
- @classmethod
- def data_dirname(cls: T) -> Path:
- """Return the path to the base data directory."""
- return Path(__file__).resolve().parents[2] / "data"
-
def config(self) -> Dict:
"""Return important settings of the dataset."""
return {