diff options
Diffstat (limited to 'text_recognizer/data/iam_extended_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index df0c0e1..8b3a46c 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,5 +1,7 @@ """IAM original and sythetic dataset class.""" import attr +from typing import Optional, Tuple + from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -13,6 +15,7 @@ class IAMExtendedParagraphs(BaseDataModule): augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) + resize: Optional[Tuple[int, int]] = attr.ib(default=None) def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( @@ -22,6 +25,7 @@ class IAMExtendedParagraphs(BaseDataModule): train_fraction=self.train_fraction, augment=self.augment, word_pieces=self.word_pieces, + resize=self.resize, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( mapping=self.mapping, @@ -30,6 +34,7 @@ class IAMExtendedParagraphs(BaseDataModule): train_fraction=self.train_fraction, augment=self.augment, word_pieces=self.word_pieces, + resize=self.resize, ) self.dims = self.iam_paragraphs.dims |