From e3741de333a3a43a7968241b6eccaaac66dd7b20 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 21 Mar 2021 22:33:58 +0100 Subject: Working on EMNIST Lines dataset --- text_recognizer/datasets/base_dataset.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) (limited to 'text_recognizer/datasets/base_dataset.py') diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py index 7322d7f..a004b8d 100644 --- a/text_recognizer/datasets/base_dataset.py +++ b/text_recognizer/datasets/base_dataset.py @@ -17,12 +17,14 @@ class BaseDataset(Dataset): target_transform (Callable): Fucntion that takes a target and applies target transforms. """ - def __init__(self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: + + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Callable = None, + target_transform: Callable = None, + ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length.") self.data = data @@ -30,11 +32,10 @@ class BaseDataset(Dataset): self.transform = transform self.target_transform = target_transform - def __len__(self) -> int: """Return the length of the dataset.""" return len(self.data) - + def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a datum and its target, after processing by transforms. @@ -56,7 +57,9 @@ class BaseDataset(Dataset): return datum, target -def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor: +def convert_strings_to_labels( + strings: Sequence[str], mapping: Dict[str, int], length: int +) -> Tensor: """ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with and tokens, and padded wiht the

token. -- cgit v1.2.3-70-g09d2