diff options
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 19 |
1 file changed, 14 insertions, 5 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 8640d92..e08130d 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -9,8 +9,7 @@ from torch.utils.data import Dataset @attr.s class BaseDataset(Dataset): - """ - Base Dataset class that processes data and targets through optional transfroms. + r"""Base Dataset class that processes data and targets through optional transforms. Args: data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images. @@ -26,9 +25,11 @@ class BaseDataset(Dataset): target_transform: Optional[Callable] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: + """Pre init constructor.""" super().__init__() def __attrs_post_init__(self) -> None: + """Post init constructor.""" if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") @@ -60,9 +61,17 @@ class BaseDataset(Dataset): def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int ) -> Tensor: - """ - Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens, and padded wiht the <p> token. + r"""Convert a sequence of N strings to (N, length) ndarray. + + Wrap each string with <s> and </s> tokens, and pad with the <p> token. + + Args: + strings (Sequence[str]): Sequence of strings. + mapping (Dict[str, int]): Mapping of characters and digits to integers. + length (int): Max length of all strings. + + Returns: + Tensor: Target with emnist mapping indices. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"] for i, string in enumerate(strings): |