diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
commit | 30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch) | |
tree | 08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/iam_synthetic_paragraphs.py | |
parent | ad3f404d36a9add32992698dd083d368f3b96812 (diff) |
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/iam_synthetic_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 18 |
1 files changed, 10 insertions, 8 deletions
```diff
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index f253427..5718747 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -12,7 +12,6 @@ from text_recognizer.data.base_dataset import (
     BaseDataset,
     convert_strings_to_labels,
 )
-from text_recognizer.data.emnist_mapping import EmnistMapping
 from text_recognizer.data.iam import IAM
 from text_recognizer.data.iam_lines import (
     line_crops_and_labels,
@@ -21,13 +20,13 @@ from text_recognizer.data.iam_lines import (
 )
 from text_recognizer.data.iam_paragraphs import (
     get_dataset_properties,
-    get_target_transform,
-    get_transform,
     IAMParagraphs,
     IMAGE_SCALE_FACTOR,
     NEW_LINE_TOKEN,
     resize_image,
 )
+from text_recognizer.data.mappings.emnist_mapping import EmnistMapping
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
 
 
 PROCESSED_DATA_DIRNAME = (
@@ -83,10 +82,8 @@ class IAMSyntheticParagraphs(IAMParagraphs):
         self.data_train = BaseDataset(
             data,
             targets,
-            transform=get_transform(
-                image_shape=self.dims[1:], augment=self.augment, resize=self.resize
-            ),
-            target_transform=get_target_transform(self.word_pieces),
+            transform=self.transform,
+            target_transform=self.target_transforms,
         )
 
     def __repr__(self) -> str:
@@ -101,6 +98,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
             return basic
 
         x, y = next(iter(self.train_dataloader()))
+        x = x[0] if isinstance(x, list) else x
         data = (
             f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n"
             f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
@@ -220,4 +218,8 @@ def generate_random_batches(
 
 def create_synthetic_iam_paragraphs() -> None:
     """Creates and prints IAM Synthetic Paragraphs dataset."""
-    load_and_print_info(IAMSyntheticParagraphs)
+    transform = load_transform_from_file("transform/paragraphs.yaml")
+    test_transform = load_transform_from_file("test_transform/paragraphs.yaml")
+    load_and_print_info(
+        IAMSyntheticParagraphs(transform=transform, test_transform=test_transform)
+    )
```