diff options
Diffstat (limited to 'text_recognizer/datasets/base_dataset.py')
-rw-r--r-- | text_recognizer/datasets/base_dataset.py | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py index 7322d7f..a004b8d 100644 --- a/text_recognizer/datasets/base_dataset.py +++ b/text_recognizer/datasets/base_dataset.py @@ -17,12 +17,14 @@ class BaseDataset(Dataset): target_transform (Callable): Fucntion that takes a target and applies target transforms. """ - def __init__(self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: + + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Callable = None, + target_transform: Callable = None, + ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length.") self.data = data @@ -30,11 +32,10 @@ class BaseDataset(Dataset): self.transform = transform self.target_transform = target_transform - def __len__(self) -> int: """Return the length of the dataset.""" return len(self.data) - + def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a datum and its target, after processing by transforms. @@ -56,7 +57,9 @@ class BaseDataset(Dataset): return datum, target -def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor: +def convert_strings_to_labels( + strings: Sequence[str], mapping: Dict[str, int], length: int +) -> Tensor: """ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens, and padded wiht the <P> token. |