summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r--text_recognizer/data/iam_lines.py75
1 files changed, 12 insertions, 63 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 7a063c1..efd1cde 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -5,7 +5,6 @@ dataset.
"""
import json
from pathlib import Path
-import random
from typing import List, Sequence, Tuple
import attr
@@ -13,19 +12,17 @@ from loguru import logger as log
import numpy as np
from PIL import Image, ImageFile, ImageOps
from torch import Tensor
-import torchvision.transforms as T
-from torchvision.transforms.functional import InterpolationMode
-from text_recognizer.data import image_utils
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
split_dataset,
)
-from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.iam_paragraphs import get_target_transform
+from text_recognizer.data.mappings.emnist_mapping import EmnistMapping
+from text_recognizer.data.utils import image_utils
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -42,9 +39,6 @@ MAX_WORD_PIECE_LENGTH = 72
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- word_pieces: bool = attr.ib(default=False)
- augment: bool = attr.ib(default=True)
- train_fraction: float = attr.ib(default=0.8)
dims: Tuple[int, int, int] = attr.ib(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
@@ -94,10 +88,8 @@ class IAMLines(BaseDataModule):
data_train = BaseDataset(
x_train,
y_train,
- transform=get_transform(IMAGE_WIDTH, self.augment),
- target_transform=get_target_transform(
- self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH
- ),
+ transform=self.transform,
+ target_transform=self.target_transform,
)
self.data_train, self.data_val = split_dataset(
@@ -118,10 +110,8 @@ class IAMLines(BaseDataModule):
self.data_test = BaseDataset(
x_test,
y_test,
- transform=get_transform(IMAGE_WIDTH),
- target_transform=get_target_transform(
- self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH
- ),
+ transform=self.test_transform,
+ target_transform=self.target_transform,
)
if stage is None:
@@ -147,6 +137,8 @@ class IAMLines(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
xt, yt = next(iter(self.test_dataloader()))
+ x = x[0] if isinstance(x, list) else x
+ xt = xt[0] if isinstance(xt, list) else xt
data = (
"Train/val/test sizes: "
f"{len(self.data_train)}, "
@@ -217,51 +209,8 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
return crops, labels
-def get_transform(image_width: int, augment: bool = False) -> T.Compose:
- """Augment with brigthness, rotation, slant, translation, scale, and noise."""
-
- def embed_crop(
- crop: Image, augment: bool = augment, image_width: int = image_width
- ) -> Image:
- # Crop is PIL.Image of dtype="L" (so value range is [0, 255])
- image = Image.new("L", (image_width, IMAGE_HEIGHT))
-
- # Resize crop.
- crop_width, crop_height = crop.size
- new_crop_height = IMAGE_HEIGHT
- new_crop_width = int(new_crop_height * (crop_width / crop_height))
-
- if augment:
- # Add random stretching
- new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
- new_crop_width = min(new_crop_width, image_width)
- crop_resized = crop.resize(
- (new_crop_width, new_crop_height), resample=Image.BILINEAR
- )
-
- # Embed in image
- x = min(28, image_width - new_crop_width)
- y = IMAGE_HEIGHT - new_crop_height
- image.paste(crop_resized, (x, y))
-
- return image
-
- transfroms_list = [T.Lambda(embed_crop)]
-
- if augment:
- transfroms_list += [
- T.ColorJitter(brightness=(0.8, 1.6)),
- T.RandomAffine(
- degrees=1,
- shear=(-30, 20),
- interpolation=InterpolationMode.BILINEAR,
- fill=0,
- ),
- ]
- transfroms_list.append(T.ToTensor())
- return T.Compose(transfroms_list)
-
-
def generate_iam_lines() -> None:
"""Displays Iam Lines dataset statistics."""
- load_and_print_info(IAMLines)
+ transform = load_transform_from_file("transform/iam_lines.yaml")
+ test_transform = load_transform_from_file("test_transform/iam_lines.yaml")
+ load_and_print_info(IAMLines(transform=transform, test_transform=test_transform))