summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/base_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/datasets/base_dataset.py')
-rw-r--r-- text_recognizer/datasets/base_dataset.py 8
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
index a004b8d..a9e9c24 100644
--- a/text_recognizer/datasets/base_dataset.py
+++ b/text_recognizer/datasets/base_dataset.py
@@ -61,13 +61,13 @@ def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int
) -> Tensor:
"""
- Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens,
- and padded wiht the <P> token.
+ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens,
+ and padded with the <p> token.
"""
- labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"]
+ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
for i, string in enumerate(strings):
tokens = list(string)
- tokens = ["<S>", *tokens, "</S>"]
+ tokens = ["<s>", *tokens, "</s>"]
for j, token in enumerate(tokens):
labels[i, j] = mapping[token]
return labels