From 5fca1bb6aeb48f106750d66b46317acf38cd8a10 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 18 Sep 2021 17:42:14 +0200
Subject: Add resize attribute to IAM paragraphs

---
 text_recognizer/data/iam_extended_paragraphs.py  |  5 +++++
 text_recognizer/data/iam_paragraphs.py           | 23 ++++++++++++++++++-----
 text_recognizer/data/iam_synthetic_paragraphs.py |  2 +-
 3 files changed, 24 insertions(+), 6 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index df0c0e1..8b3a46c 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,5 +1,7 @@
 """IAM original and sythetic dataset class."""
 import attr
+from typing import Optional, Tuple
+
 from torch.utils.data import ConcatDataset
 
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -13,6 +15,7 @@ class IAMExtendedParagraphs(BaseDataModule):
     augment: bool = attr.ib(default=True)
     train_fraction: float = attr.ib(default=0.8)
     word_pieces: bool = attr.ib(default=False)
+    resize: Optional[Tuple[int, int]] = attr.ib(default=None)
 
     def __attrs_post_init__(self) -> None:
         self.iam_paragraphs = IAMParagraphs(
@@ -22,6 +25,7 @@ class IAMExtendedParagraphs(BaseDataModule):
             train_fraction=self.train_fraction,
             augment=self.augment,
             word_pieces=self.word_pieces,
+            resize=self.resize,
         )
         self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
             mapping=self.mapping,
@@ -30,6 +34,7 @@ class IAMExtendedParagraphs(BaseDataModule):
             train_fraction=self.train_fraction,
             augment=self.augment,
             word_pieces=self.word_pieces,
+            resize=self.resize,
         )
 
         self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 3509b92..262533f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -40,6 +40,7 @@ class IAMParagraphs(BaseDataModule):
     word_pieces: bool = attr.ib(default=False)
     augment: bool = attr.ib(default=True)
     train_fraction: float = attr.ib(default=0.8)
+    resize: Optional[Tuple[int, int]] = attr.ib(default=None)
 
     # Placeholders
     dims: Tuple[int, int, int] = attr.ib(
@@ -79,7 +80,9 @@ class IAMParagraphs(BaseDataModule):
     def setup(self, stage: str = None) -> None:
         """Loads the data for training/testing."""
 
-        def _load_dataset(split: str, augment: bool) -> BaseDataset:
+        def _load_dataset(
+            split: str, augment: bool, resize: Optional[Tuple[int, int]]
+        ) -> BaseDataset:
             crops, labels = _load_processed_crops_and_labels(split)
             data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
             targets = convert_strings_to_labels(
@@ -90,7 +93,9 @@ class IAMParagraphs(BaseDataModule):
             return BaseDataset(
                 data,
                 targets,
-                transform=get_transform(image_shape=self.dims[1:], augment=augment),
+                transform=get_transform(
+                    image_shape=self.dims[1:], augment=augment, resize=resize
+                ),
                 target_transform=get_target_transform(self.word_pieces),
             )
 
@@ -98,13 +103,17 @@ class IAMParagraphs(BaseDataModule):
         _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
 
         if stage == "fit" or stage is None:
-            data_train = _load_dataset(split="train", augment=self.augment)
+            data_train = _load_dataset(
+                split="train", augment=self.augment, resize=self.resize
+            )
             self.data_train, self.data_val = split_dataset(
                 dataset=data_train, fraction=self.train_fraction, seed=SEED
             )
 
         if stage == "test" or stage is None:
-            self.data_test = _load_dataset(split="test", augment=False)
+            self.data_test = _load_dataset(
+                split="test", augment=False, resize=self.resize
+            )
 
     def __repr__(self) -> str:
         """Return information about the dataset."""
@@ -260,7 +269,9 @@ def _load_processed_crops_and_labels(
     return ordered_crops, ordered_labels
 
 
-def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
+def get_transform(
+    image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]]
+) -> T.Compose:
     """Get transformations for images."""
     if augment:
         transforms_list = [
@@ -278,6 +289,8 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
         ]
     else:
         transforms_list = [T.CenterCrop(image_shape)]
+    if resize is not None:
+        transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR))
     transforms_list.append(T.ToTensor())
     return T.Compose(transforms_list)
 
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 24ca896..b9cf90d 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -84,7 +84,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
                 data,
                 targets,
                 transform=get_transform(
-                    image_shape=self.dims[1:], augment=self.augment
+                    image_shape=self.dims[1:], augment=self.augment, resize=self.resize
                 ),
                 target_transform=get_target_transform(self.word_pieces),
             )
-- 
cgit v1.2.3-70-g09d2