diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 5 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 23 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 2 |
3 files changed, 24 insertions, 6 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index df0c0e1..8b3a46c 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,5 +1,7 @@ """IAM original and sythetic dataset class.""" import attr +from typing import Optional, Tuple + from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -13,6 +15,7 @@ class IAMExtendedParagraphs(BaseDataModule): augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) + resize: Optional[Tuple[int, int]] = attr.ib(default=None) def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( @@ -22,6 +25,7 @@ class IAMExtendedParagraphs(BaseDataModule): train_fraction=self.train_fraction, augment=self.augment, word_pieces=self.word_pieces, + resize=self.resize, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( mapping=self.mapping, @@ -30,6 +34,7 @@ class IAMExtendedParagraphs(BaseDataModule): train_fraction=self.train_fraction, augment=self.augment, word_pieces=self.word_pieces, + resize=self.resize, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 3509b92..262533f 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -40,6 +40,7 @@ class IAMParagraphs(BaseDataModule): word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) + resize: Optional[Tuple[int, int]] = attr.ib(default=None) # Placeholders dims: Tuple[int, int, int] = attr.ib( @@ -79,7 +80,9 @@ class IAMParagraphs(BaseDataModule): def setup(self, stage: str = None) -> None: """Loads the data for training/testing.""" - def _load_dataset(split: str, augment: bool) -> BaseDataset: + def _load_dataset( + split: str, augment: bool, resize: Optional[Tuple[int, int]] + ) -> BaseDataset: crops, labels = _load_processed_crops_and_labels(split) data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] targets = convert_strings_to_labels( @@ -90,7 +93,9 @@ class IAMParagraphs(BaseDataModule): return BaseDataset( data, targets, - transform=get_transform(image_shape=self.dims[1:], augment=augment), + transform=get_transform( + image_shape=self.dims[1:], augment=augment, resize=resize + ), target_transform=get_target_transform(self.word_pieces), ) @@ -98,13 +103,17 @@ class IAMParagraphs(BaseDataModule): _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims) if stage == "fit" or stage is None: - data_train = _load_dataset(split="train", augment=self.augment) + data_train = _load_dataset( + split="train", augment=self.augment, resize=self.resize + ) self.data_train, self.data_val = split_dataset( dataset=data_train, fraction=self.train_fraction, seed=SEED ) if stage == "test" or stage is None: - self.data_test = _load_dataset(split="test", augment=False) + self.data_test = _load_dataset( + split="test", augment=False, resize=self.resize + ) def __repr__(self) -> str: """Return information about the dataset.""" @@ -260,7 +269,9 @@ def _load_processed_crops_and_labels( return ordered_crops, ordered_labels -def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose: +def get_transform( + image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]] +) -> T.Compose: """Get transformations for images.""" if augment: transforms_list = [ @@ -278,6 +289,8 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose: ] else: transforms_list = [T.CenterCrop(image_shape)] + if resize is not None: + transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR)) transforms_list.append(T.ToTensor()) return T.Compose(transforms_list) diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 24ca896..b9cf90d 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -84,7 +84,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): data, targets, transform=get_transform( - image_shape=self.dims[1:], augment=self.augment + image_shape=self.dims[1:], augment=self.augment, resize=self.resize ), target_transform=get_target_transform(self.word_pieces), ) |