From 30e3ae483c846418b04ed48f014a4af2cf9a0771 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:03:11 +0200 Subject: Update transforms in datamodule/set --- text_recognizer/data/iam_synthetic_paragraphs.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) (limited to 'text_recognizer/data/iam_synthetic_paragraphs.py') diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index f253427..5718747 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -12,7 +12,6 @@ from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, ) -from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( line_crops_and_labels, @@ -21,13 +20,13 @@ from text_recognizer.data.iam_lines import ( ) from text_recognizer.data.iam_paragraphs import ( get_dataset_properties, - get_target_transform, - get_transform, IAMParagraphs, IMAGE_SCALE_FACTOR, NEW_LINE_TOKEN, resize_image, ) +from text_recognizer.data.mappings.emnist_mapping import EmnistMapping +from text_recognizer.data.transforms.load_transform import load_transform_from_file PROCESSED_DATA_DIRNAME = ( @@ -83,10 +82,8 @@ class IAMSyntheticParagraphs(IAMParagraphs): self.data_train = BaseDataset( data, targets, - transform=get_transform( - image_shape=self.dims[1:], augment=self.augment, resize=self.resize - ), - target_transform=get_target_transform(self.word_pieces), + transform=self.transform, + target_transform=self.target_transforms, ) def __repr__(self) -> str: @@ -101,6 +98,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): return basic x, y = next(iter(self.train_dataloader())) + x = x[0] if isinstance(x, list) else x data = ( f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" @@ -220,4 +218,8 @@ def generate_random_batches( def create_synthetic_iam_paragraphs() -> None: """Creates and prints IAM Synthetic Paragraphs dataset.""" - load_and_print_info(IAMSyntheticParagraphs) + transform = load_transform_from_file("transform/paragraphs.yaml") + test_transform = load_transform_from_file("test_transform/paragraphs.yaml") + load_and_print_info( + IAMSyntheticParagraphs(transform=transform, test_transform=test_transform) + ) -- cgit v1.2.3-70-g09d2