summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_dataset.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-03 00:31:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-03 00:31:00 +0200
commit3a21c29e2eff4378c63717f8920ca3ccbfef013c (patch)
treeba46504d7baa8d4fb5bfd473acf99a7a184b330c /text_recognizer/data/base_dataset.py
parent75eb34020620584247313926527019471411f6af (diff)
Lint files
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r--text_recognizer/data/base_dataset.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 8640d92..e08130d 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -9,8 +9,7 @@ from torch.utils.data import Dataset
@attr.s
class BaseDataset(Dataset):
- """
- Base Dataset class that processes data and targets through optional transfroms.
+ r"""Base Dataset class that processes data and targets through optional transfroms.
Args:
data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images.
@@ -26,9 +25,11 @@ class BaseDataset(Dataset):
target_transform: Optional[Callable] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
+ """Pre init constructor."""
super().__init__()
def __attrs_post_init__(self) -> None:
+ """Post init constructor."""
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
@@ -60,9 +61,17 @@ class BaseDataset(Dataset):
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int
) -> Tensor:
- """
- Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens,
- and padded wiht the <p> token.
+ r"""Convert a sequence of N strings to (N, length) ndarray.
+
+ Add each string with <s> and </s> tokens, and padded wiht the <p> token.
+
+ Args:
+ strings (Sequence[str]): Sequence of strings.
+ mapping (Dict[str, int]): Mapping of characters and digits to integers.
+ length (int): Max lenght of all strings.
+
+ Returns:
+ Tensor: Target with emnist mapping indices.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
for i, string in enumerate(strings):