diff options
Diffstat (limited to 'text_recognizer/data/iam_extended_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 29 |
1 files changed, 16 insertions, 13 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 8b3a46c..87b8ef1 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,21 +1,17 @@ """IAM original and sythetic dataset class.""" import attr -from typing import Optional, Tuple from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs +from text_recognizer.data.transforms.load_transform import load_transform_from_file @attr.s(auto_attribs=True, repr=False) class IAMExtendedParagraphs(BaseDataModule): - - augment: bool = attr.ib(default=True) - train_fraction: float = attr.ib(default=0.8) - word_pieces: bool = attr.ib(default=False) - resize: Optional[Tuple[int, int]] = attr.ib(default=None) + """A dataset with synthetic and real handwritten paragraph.""" def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( @@ -23,18 +19,18 @@ class IAMExtendedParagraphs(BaseDataModule): batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, - augment=self.augment, - word_pieces=self.word_pieces, - resize=self.resize, + transform=self.transform, + test_transform=self.test_transform, + target_transform=self.target_transform, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( mapping=self.mapping, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, - augment=self.augment, - word_pieces=self.word_pieces, - resize=self.resize, + transform=self.transform, + test_transform=self.test_transform, + target_transform=self.target_transform, ) self.dims = self.iam_paragraphs.dims @@ -69,6 +65,8 @@ class IAMExtendedParagraphs(BaseDataModule): x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) + x = x[0] if isinstance(x, list) else x + xt = xt[0] if isinstance(xt, list) else xt data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" @@ -80,4 +78,9 @@ class IAMExtendedParagraphs(BaseDataModule): def show_dataset_info() -> None: - load_and_print_info(IAMExtendedParagraphs) + """Displays Iam extended dataset information.""" + transform = load_transform_from_file("transform/paragraphs.yaml") + test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml") + load_and_print_info( + IAMExtendedParagraphs(transform=transform, test_transform=test_transform) + ) |