summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r--text_recognizer/data/emnist_lines.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 88aac0d..8a31c44 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -12,7 +12,7 @@ from torch import Tensor
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.stems.line import LineStem
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
import text_recognizer.metadata.emnist_lines as metadata
@@ -23,7 +23,7 @@ class EMNISTLines(BaseDataModule):
def __init__(
self,
- mapping: EmnistMapping,
+ tokenizer: Tokenizer,
transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
@@ -39,7 +39,7 @@ class EMNISTLines(BaseDataModule):
num_test: int = 2_000,
) -> None:
super().__init__(
- mapping,
+ tokenizer,
transform,
test_transform,
target_transform,
@@ -120,7 +120,7 @@ class EMNISTLines(BaseDataModule):
"EMNISTLines2 Dataset\n" # pylint: disable=no-member
f"Min overlap: {self.min_overlap}\n"
f"Max overlap: {self.max_overlap}\n"
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
f"Dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
@@ -153,17 +153,17 @@ class EMNISTLines(BaseDataModule):
if split == "train":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, self.mapping.mapping
+ emnist.x_train, emnist.y_train, self.tokenizer.mapping
)
num = self.num_train
elif split == "val":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, self.mapping.mapping
+ emnist.x_train, emnist.y_train, self.tokenizer.mapping
)
num = self.num_val
else:
samples_by_char = _get_samples_by_char(
- emnist.x_test, emnist.y_test, self.mapping.mapping
+ emnist.x_test, emnist.y_test, self.tokenizer.mapping
)
num = self.num_test
@@ -178,7 +178,7 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, self.mapping.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
+ y, self.tokenizer.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")