diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-23 21:55:42 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-23 21:55:42 +0100 |
commit | ae589fb3ffdbf6c4bb1ae35345f7a3665deeebc5 (patch) | |
tree | 1702f74c069679ebdd74a03892275c6eb3a80ffd /text_recognizer/datasets/base_dataset.py | |
parent | e3741de333a3a43a7968241b6eccaaac66dd7b20 (diff) |
refactored emnist lines dataset
Diffstat (limited to 'text_recognizer/datasets/base_dataset.py')
-rw-r--r-- | text_recognizer/datasets/base_dataset.py | 8 |
1 files changed, 4 insertions, 4 deletions
```diff
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
index a004b8d..a9e9c24 100644
--- a/text_recognizer/datasets/base_dataset.py
+++ b/text_recognizer/datasets/base_dataset.py
@@ -61,13 +61,13 @@ def convert_strings_to_labels(
     strings: Sequence[str], mapping: Dict[str, int], length: int
 ) -> Tensor:
     """
-    Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens,
-    and padded wiht the <P> token.
+    Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens,
+    and padded wiht the <p> token.
     """
-    labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"]
+    labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
     for i, string in enumerate(strings):
         tokens = list(string)
-        tokens = ["<S>", *tokens, "</S>"]
+        tokens = ["<s>", *tokens, "</s>"]
         for j, token in enumerate(tokens):
             labels[i, j] = mapping[token]
     return labels
```