From 75909723fa2b1f6245d5c5422e4f2e88b8a26052 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 15 Nov 2020 17:40:44 +0100 Subject: Able to generate support files for lines datasets. --- src/text_recognizer/datasets/dataset.py | 39 +++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 9 deletions(-) (limited to 'src/text_recognizer/datasets/dataset.py') diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index 2de7f09..95063bc 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -1,11 +1,12 @@ """Abstract dataset class.""" -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from torch.utils import data from torchvision.transforms import ToTensor +import text_recognizer.datasets.transforms as transforms from text_recognizer.datasets.util import EmnistMapper @@ -16,8 +17,8 @@ class Dataset(data.Dataset): self, train: bool, subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Optional[List[Dict]] = None, + target_transform: Optional[List[Dict]] = None, init_token: Optional[str] = None, pad_token: Optional[str] = None, eos_token: Optional[str] = None, @@ -27,8 +28,8 @@ class Dataset(data.Dataset): Args: train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[Callable]): Transform(s) for input data. Defaults to None. - target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. + target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. init_token (Optional[str]): String representing the start of sequence token. Defaults to None. pad_token (Optional[str]): String representing the pad token. Defaults to None. eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. @@ -53,14 +54,34 @@ class Dataset(data.Dataset): self.num_classes = self.mapper.num_classes # Set transforms. - self.transform = transform if transform is not None else ToTensor() - self.target_transform = ( - target_transform if target_transform is not None else torch.tensor - ) + self.transform = self._configure_transform(transform) + self.target_transform = self._configure_target_transform(target_transform) self._data = None self._targets = None + def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: + transform_list = [] + if transform is not None: + for t in transform: + t_type = t["type"] + t_args = t["args"] or {} + transform_list.append(getattr(transforms, t_type)(**t_args)) + else: + transform_list.append(ToTensor()) + return transforms.Compose(transform_list) + + def _configure_target_transform( + self, target_transform: List[Dict] + ) -> transforms.Compose: + target_transform_list = [torch.tensor] + if target_transform is not None: + for t in target_transform: + t_type = t["type"] + t_args = t["args"] or {} + target_transform_list.append(getattr(transforms, t_type)(**t_args)) + return transforms.Compose(target_transform_list) + @property def data(self) -> Tensor: """The input data.""" -- cgit v1.2.3-70-g09d2