summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
commit30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch)
tree08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/iam_paragraphs.py
parentad3f404d36a9add32992698dd083d368f3b96812 (diff)
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py68
1 files changed, 17 insertions, 51 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 254c7f5..26674e0 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -8,7 +8,6 @@ from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
-from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -17,9 +16,9 @@ from text_recognizer.data.base_dataset import (
convert_strings_to_labels,
split_dataset,
)
-from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.transforms import WordPiece
+from text_recognizer.data.mappings.emnist_mapping import EmnistMapping
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
@@ -38,11 +37,6 @@ MAX_WORD_PIECE_LENGTH = 451
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- word_pieces: bool = attr.ib(default=False)
- augment: bool = attr.ib(default=True)
- train_fraction: float = attr.ib(default=0.8)
- resize: Optional[Tuple[int, int]] = attr.ib(default=None)
-
# Placeholders
dims: Tuple[int, int, int] = attr.ib(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
@@ -82,7 +76,7 @@ class IAMParagraphs(BaseDataModule):
"""Loads the data for training/testing."""
def _load_dataset(
- split: str, augment: bool, resize: Optional[Tuple[int, int]]
+ split: str, transform: T.Compose, target_transform: T.Compose
) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
@@ -92,12 +86,7 @@ class IAMParagraphs(BaseDataModule):
length=self.output_dims[0],
)
return BaseDataset(
- data,
- targets,
- transform=get_transform(
- image_shape=self.dims[1:], augment=augment, resize=resize
- ),
- target_transform=get_target_transform(self.word_pieces),
+ data, targets, transform=transform, target_transform=target_transform,
)
log.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -105,7 +94,9 @@ class IAMParagraphs(BaseDataModule):
if stage == "fit" or stage is None:
data_train = _load_dataset(
- split="train", augment=self.augment, resize=self.resize
+ split="train",
+ transform=self.transform,
+ target_transform=self.target_transform,
)
self.data_train, self.data_val = split_dataset(
dataset=data_train, fraction=self.train_fraction, seed=SEED
@@ -113,7 +104,9 @@ class IAMParagraphs(BaseDataModule):
if stage == "test" or stage is None:
self.data_test = _load_dataset(
- split="test", augment=False, resize=self.resize
+ split="test",
+ transform=self.test_transform,
+ target_transform=self.target_transform,
)
def __repr__(self) -> str:
@@ -130,6 +123,8 @@ class IAMParagraphs(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
xt, yt = next(iter(self.test_dataloader()))
+ x = x[0] if isinstance(x, list) else x
+ xt = xt[0] if isinstance(xt, list) else xt
data = (
"Train/val/test sizes: "
f"{len(self.data_train)}, "
@@ -274,39 +269,6 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def get_transform(
- image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]]
-) -> T.Compose:
- """Get transformations for images."""
- if augment:
- transforms_list = [
- T.RandomCrop(
- size=image_shape,
- padding=None,
- pad_if_needed=True,
- fill=0,
- padding_mode="constant",
- ),
- T.ColorJitter(brightness=(0.8, 1.6)),
- T.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
- ),
- ]
- else:
- transforms_list = [T.CenterCrop(image_shape)]
- if resize is not None:
- transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR))
- transforms_list.append(T.ToTensor())
- return T.Compose(transforms_list)
-
-
-def get_target_transform(
- word_pieces: bool, max_len: int = MAX_WORD_PIECE_LENGTH
-) -> Optional[T.Compose]:
- """Transform emnist characters to word pieces."""
- return T.Compose([WordPiece(max_len=max_len)]) if word_pieces else None
-
-
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
return PROCESSED_DATA_DIRNAME / split / "_labels.json"
@@ -324,4 +286,8 @@ def _num_lines(label: str) -> int:
def create_iam_paragraphs() -> None:
"""Loads and displays dataset statistics."""
- load_and_print_info(IAMParagraphs)
+ transform = load_transform_from_file("transform/paragraphs.yaml")
+ test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml")
+ load_and_print_info(
+ IAMParagraphs(transform=transform, test_transform=test_transform)
+ )