path: root/text_recognizer/data/iam_lines.py
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r--  text_recognizer/data/iam_lines.py  15
1 file changed, 7 insertions(+), 8 deletions(-)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index e60d1ba..a0d9b59 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -19,8 +19,7 @@ from text_recognizer.data.base_dataset import (
     split_dataset,
 )
 from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
+from text_recognizer.data.tokenizer import Tokenizer
 from text_recognizer.data.stems.line import IamLinesStem
 from text_recognizer.data.utils import image_utils
 import text_recognizer.metadata.iam_lines as metadata
@@ -33,7 +32,7 @@ class IAMLines(BaseDataModule):
     def __init__(
         self,
-        mapping: EmnistMapping,
+        tokenizer: Tokenizer,
         transform: Optional[Callable] = None,
         test_transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
@@ -43,7 +42,7 @@ class IAMLines(BaseDataModule):
         pin_memory: bool = True,
     ) -> None:
         super().__init__(
-            mapping,
+            tokenizer,
             transform,
             test_transform,
             target_transform,
@@ -61,7 +60,7 @@ class IAMLines(BaseDataModule):
             return
         log.info("Cropping IAM lines regions...")
-        iam = IAM(mapping=EmnistMapping())
+        iam = IAM(tokenizer=self.tokenizer)
         iam.prepare_data()
         crops_train, labels_train = line_crops_and_labels(iam, "train")
         crops_test, labels_test = line_crops_and_labels(iam, "test")
@@ -100,7 +99,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
- labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
+ labels_train, self.tokenizer.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
x_train,
@@ -122,7 +121,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
- labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
+ labels_test, self.tokenizer.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
x_test,
@@ -144,7 +143,7 @@ class IAMLines(BaseDataModule):
"""Return information about the dataset."""
basic = (
"IAM Lines dataset\n"
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.tokenizer)}\n"
f"Input dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
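
For orientation, a minimal usage sketch of the data module after this commit. It is not part of the diff: it assumes Tokenizer() can be constructed with default arguments (mirroring the EmnistMapping() default it replaces), that setup() takes the usual LightningDataModule stage argument, and that the info string in the last hunk is built in __repr__. Only keyword arguments visible in the hunks above are used.

# Usage sketch only, not part of this commit.
# Assumption: Tokenizer() accepts default arguments, analogous to EmnistMapping().
from text_recognizer.data.iam_lines import IAMLines
from text_recognizer.data.tokenizer import Tokenizer

tokenizer = Tokenizer()

datamodule = IAMLines(
    tokenizer=tokenizer,  # was mapping=EmnistMapping() before this change
    pin_memory=True,      # keyword shown in the diffed __init__ signature
)

datamodule.prepare_data()  # crops IAM line regions via IAM(tokenizer=self.tokenizer)
datamodule.setup("fit")    # encodes labels with tokenizer.inverse_mapping
print(datamodule)          # assumes the info string ("Num classes: len(tokenizer)") is __repr__

The pattern of the commit is consistent: every place that previously consumed the mapping (the constructor, the IAM download in prepare_data, label encoding in setup, and the info string) now goes through the single tokenizer attribute.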