summary | refs | log | tree | commit | diff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r-- text_recognizer/data/iam_extended_paragraphs.py | 5
-rw-r--r-- text_recognizer/data/iam_paragraphs.py | 23
-rw-r--r-- text_recognizer/data/iam_synthetic_paragraphs.py | 2
3 files changed, 24 insertions, 6 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index df0c0e1..8b3a46c 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,5 +1,7 @@
"""IAM original and sythetic dataset class."""
import attr
+from typing import Optional, Tuple
+
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -13,6 +15,7 @@ class IAMExtendedParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
+ resize: Optional[Tuple[int, int]] = attr.ib(default=None)
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
@@ -22,6 +25,7 @@ class IAMExtendedParagraphs(BaseDataModule):
train_fraction=self.train_fraction,
augment=self.augment,
word_pieces=self.word_pieces,
+ resize=self.resize,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
mapping=self.mapping,
@@ -30,6 +34,7 @@ class IAMExtendedParagraphs(BaseDataModule):
train_fraction=self.train_fraction,
augment=self.augment,
word_pieces=self.word_pieces,
+ resize=self.resize,
)
self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 3509b92..262533f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -40,6 +40,7 @@ class IAMParagraphs(BaseDataModule):
word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
+ resize: Optional[Tuple[int, int]] = attr.ib(default=None)
# Placeholders
dims: Tuple[int, int, int] = attr.ib(
@@ -79,7 +80,9 @@ class IAMParagraphs(BaseDataModule):
def setup(self, stage: str = None) -> None:
"""Loads the data for training/testing."""
- def _load_dataset(split: str, augment: bool) -> BaseDataset:
+ def _load_dataset(
+ split: str, augment: bool, resize: Optional[Tuple[int, int]]
+ ) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
@@ -90,7 +93,9 @@ class IAMParagraphs(BaseDataModule):
return BaseDataset(
data,
targets,
- transform=get_transform(image_shape=self.dims[1:], augment=augment),
+ transform=get_transform(
+ image_shape=self.dims[1:], augment=augment, resize=resize
+ ),
target_transform=get_target_transform(self.word_pieces),
)
@@ -98,13 +103,17 @@ class IAMParagraphs(BaseDataModule):
_validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
if stage == "fit" or stage is None:
- data_train = _load_dataset(split="train", augment=self.augment)
+ data_train = _load_dataset(
+ split="train", augment=self.augment, resize=self.resize
+ )
self.data_train, self.data_val = split_dataset(
dataset=data_train, fraction=self.train_fraction, seed=SEED
)
if stage == "test" or stage is None:
- self.data_test = _load_dataset(split="test", augment=False)
+ self.data_test = _load_dataset(
+ split="test", augment=False, resize=self.resize
+ )
def __repr__(self) -> str:
"""Return information about the dataset."""
@@ -260,7 +269,9 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
+def get_transform(
+ image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]]
+) -> T.Compose:
"""Get transformations for images."""
if augment:
transforms_list = [
@@ -278,6 +289,8 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
]
else:
transforms_list = [T.CenterCrop(image_shape)]
+ if resize is not None:
+ transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR))
transforms_list.append(T.ToTensor())
return T.Compose(transforms_list)
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 24ca896..b9cf90d 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -84,7 +84,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
data,
targets,
transform=get_transform(
- image_shape=self.dims[1:], augment=self.augment
+ image_shape=self.dims[1:], augment=self.augment, resize=self.resize
),
target_transform=get_target_transform(self.word_pieces),
)