summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/base_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/datasets/base_dataset.py')
-rw-r--r--text_recognizer/datasets/base_dataset.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
index 7322d7f..a004b8d 100644
--- a/text_recognizer/datasets/base_dataset.py
+++ b/text_recognizer/datasets/base_dataset.py
@@ -17,12 +17,14 @@ class BaseDataset(Dataset):
target_transform (Callable): Fucntion that takes a target and applies
target transforms.
"""
- def __init__(self,
- data: Union[Sequence, Tensor],
- targets: Union[Sequence, Tensor],
- transform: Callable = None,
- target_transform: Callable = None,
- ) -> None:
+
+ def __init__(
+ self,
+ data: Union[Sequence, Tensor],
+ targets: Union[Sequence, Tensor],
+ transform: Callable = None,
+ target_transform: Callable = None,
+ ) -> None:
if len(data) != len(targets):
raise ValueError("Data and targets must be of equal length.")
self.data = data
@@ -30,11 +32,10 @@ class BaseDataset(Dataset):
self.transform = transform
self.target_transform = target_transform
-
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.data)
-
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Return a datum and its target, after processing by transforms.
@@ -56,7 +57,9 @@ class BaseDataset(Dataset):
return datum, target
-def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor:
+def convert_strings_to_labels(
+ strings: Sequence[str], mapping: Dict[str, int], length: int
+) -> Tensor:
"""
Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens,
and padded wiht the <P> token.