summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/datasets/dataset.py
diff options
context:
space:
mode:
author aktersnurra <gustaf.rydholm@gmail.com> 2020-11-15 17:40:44 +0100
committer aktersnurra <gustaf.rydholm@gmail.com> 2020-11-15 17:40:44 +0100
commit 75909723fa2b1f6245d5c5422e4f2e88b8a26052 (patch)
tree e60c37d05c724db011d75adf9313d93839d193ac /src/text_recognizer/datasets/dataset.py
parent cad676fc423efeafde65f03e4815248f2d357011 (diff)
Able to generate support files for lines datasets.
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r-- src/text_recognizer/datasets/dataset.py 39
1 files changed, 30 insertions, 9 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 2de7f09..95063bc 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -1,11 +1,12 @@
"""Abstract dataset class."""
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils import data
from torchvision.transforms import ToTensor
+import text_recognizer.datasets.transforms as transforms
from text_recognizer.datasets.util import EmnistMapper
@@ -16,8 +17,8 @@ class Dataset(data.Dataset):
self,
train: bool,
subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
+ transform: Optional[List[Dict]] = None,
+ target_transform: Optional[List[Dict]] = None,
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
@@ -27,8 +28,8 @@ class Dataset(data.Dataset):
Args:
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
- transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
- target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
+ target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
@@ -53,14 +54,34 @@ class Dataset(data.Dataset):
self.num_classes = self.mapper.num_classes
# Set transforms.
- self.transform = transform if transform is not None else ToTensor()
- self.target_transform = (
- target_transform if target_transform is not None else torch.tensor
- )
+ self.transform = self._configure_transform(transform)
+ self.target_transform = self._configure_target_transform(target_transform)
self._data = None
self._targets = None
+    def _configure_transform(
+        self, transform: Optional[List[Dict]]
+    ) -> transforms.Compose:
+        """Builds the input transform pipeline from a list of transform specs.
+
+        Args:
+            transform (Optional[List[Dict]]): Transform specs. Each spec is a dict
+                with a "type" key naming an attribute of the transforms module and
+                an optional "args" key with the keyword arguments used to
+                instantiate it. If None, a single ToTensor transform is used.
+
+        Returns:
+            transforms.Compose: The composed input transforms.
+
+        """
+        if transform is None:
+            return transforms.Compose([ToTensor()])
+
+        transform_list = []
+        for t in transform:
+            t_type = t["type"]
+            # A missing, None, or empty "args" entry all mean "no arguments".
+            t_args = t.get("args") or {}
+            transform_list.append(getattr(transforms, t_type)(**t_args))
+        return transforms.Compose(transform_list)
+
+    def _configure_target_transform(
+        self, target_transform: Optional[List[Dict]]
+    ) -> transforms.Compose:
+        """Builds the target transform pipeline from a list of transform specs.
+
+        The pipeline always starts with torch.tensor; any configured transforms
+        are appended after it, in the order given.
+
+        Args:
+            target_transform (Optional[List[Dict]]): Transform specs. Each spec is
+                a dict with a "type" key naming an attribute of the transforms
+                module and an optional "args" key with the keyword arguments used
+                to instantiate it. If None, only torch.tensor is applied.
+
+        Returns:
+            transforms.Compose: The composed target transforms.
+
+        """
+        target_transform_list = [torch.tensor]
+        if target_transform is not None:
+            for t in target_transform:
+                t_type = t["type"]
+                # A missing, None, or empty "args" entry all mean "no arguments".
+                t_args = t.get("args") or {}
+                target_transform_list.append(getattr(transforms, t_type)(**t_args))
+        return transforms.Compose(target_transform_list)
+
@property
def data(self) -> Tensor:
"""The input data."""