diff options
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r-- | text_recognizer/data/iam_lines.py | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index e60d1ba..a0d9b59 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -19,8 +19,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.data.transforms.load_transform import load_transform_from_file +from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.data.stems.line import IamLinesStem from text_recognizer.data.utils import image_utils import text_recognizer.metadata.iam_lines as metadata @@ -33,7 +32,7 @@ class IAMLines(BaseDataModule): def __init__( self, - mapping: EmnistMapping, + tokenizer: Tokenizer, transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -43,7 +42,7 @@ class IAMLines(BaseDataModule): pin_memory: bool = True, ) -> None: super().__init__( - mapping, + tokenizer, transform, test_transform, target_transform, @@ -61,7 +60,7 @@ class IAMLines(BaseDataModule): return log.info("Cropping IAM lines regions...") - iam = IAM(mapping=EmnistMapping()) + iam = IAM(tokenizer=self.tokenizer) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") @@ -100,7 +99,7 @@ class IAMLines(BaseDataModule): raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( - labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] + labels_train, self.tokenizer.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( x_train, @@ -122,7 +121,7 @@ class IAMLines(BaseDataModule): raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( - labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] + labels_test, self.tokenizer.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( x_test, @@ -144,7 +143,7 @@ class IAMLines(BaseDataModule): """Return information about the dataset.""" basic = ( "IAM Lines dataset\n" - f"Num classes: {len(self.mapping)}\n" + f"Num classes: {len(self.tokenizer)}\n" f"Input dims: {self.dims}\n" f"Output dims: {self.output_dims}\n" ) |