summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r-- text_recognizer/data/iam_paragraphs.py 23
1 file changed, 18 insertions, 5 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 3509b92..262533f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -40,6 +40,7 @@ class IAMParagraphs(BaseDataModule):
word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
+ resize: Optional[Tuple[int, int]] = attr.ib(default=None)
# Placeholders
dims: Tuple[int, int, int] = attr.ib(
@@ -79,7 +80,9 @@ class IAMParagraphs(BaseDataModule):
def setup(self, stage: str = None) -> None:
"""Loads the data for training/testing."""
- def _load_dataset(split: str, augment: bool) -> BaseDataset:
+ def _load_dataset(
+ split: str, augment: bool, resize: Optional[Tuple[int, int]]
+ ) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
@@ -90,7 +93,9 @@ class IAMParagraphs(BaseDataModule):
return BaseDataset(
data,
targets,
- transform=get_transform(image_shape=self.dims[1:], augment=augment),
+ transform=get_transform(
+ image_shape=self.dims[1:], augment=augment, resize=resize
+ ),
target_transform=get_target_transform(self.word_pieces),
)
@@ -98,13 +103,17 @@ class IAMParagraphs(BaseDataModule):
_validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
if stage == "fit" or stage is None:
- data_train = _load_dataset(split="train", augment=self.augment)
+ data_train = _load_dataset(
+ split="train", augment=self.augment, resize=self.resize
+ )
self.data_train, self.data_val = split_dataset(
dataset=data_train, fraction=self.train_fraction, seed=SEED
)
if stage == "test" or stage is None:
- self.data_test = _load_dataset(split="test", augment=False)
+ self.data_test = _load_dataset(
+ split="test", augment=False, resize=self.resize
+ )
def __repr__(self) -> str:
"""Return information about the dataset."""
@@ -260,7 +269,9 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
+def get_transform(
+ image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]]
+) -> T.Compose:
"""Get transformations for images."""
if augment:
transforms_list = [
@@ -278,6 +289,8 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
]
else:
transforms_list = [T.CenterCrop(image_shape)]
+ if resize is not None:
+ transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR))
transforms_list.append(T.ToTensor())
return T.Compose(transforms_list)