summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:03:11 +0200
commit30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch)
tree08a309e14416e68ad351be8a3d48bf50efd80d6b
parentad3f404d36a9add32992698dd083d368f3b96812 (diff)
Update transforms in datamodule/set
-rw-r--r--text_recognizer/data/base_data_module.py8
-rw-r--r--text_recognizer/data/base_dataset.py22
-rw-r--r--text_recognizer/data/emnist.py16
-rw-r--r--text_recognizer/data/emnist_lines.py55
-rw-r--r--text_recognizer/data/iam.py2
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py29
-rw-r--r--text_recognizer/data/iam_lines.py75
-rw-r--r--text_recognizer/data/iam_paragraphs.py68
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py18
9 files changed, 104 insertions, 189 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index ee70176..3add837 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,13 +1,13 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Optional, Tuple, Type, TypeVar
+from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from text_recognizer.data.base_dataset import BaseDataset
-from text_recognizer.data.base_mapping import AbstractMapping
+from text_recognizer.data.mappings.base_mapping import AbstractMapping
T = TypeVar("T")
@@ -29,6 +29,10 @@ class BaseDataModule(LightningDataModule):
super().__init__()
mapping: Type[AbstractMapping] = attr.ib()
+ transform: Optional[Callable] = attr.ib(default=None)
+ test_transform: Optional[Callable] = attr.ib(default=None)
+ target_transform: Optional[Callable] = attr.ib(default=None)
+ train_fraction: float = attr.ib(default=0.8)
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
pin_memory: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index e08130d..b9567c7 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -6,6 +6,8 @@ import torch
from torch import Tensor
from torch.utils.data import Dataset
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
+
@attr.s
class BaseDataset(Dataset):
@@ -21,8 +23,8 @@ class BaseDataset(Dataset):
data: Union[Sequence, Tensor] = attr.ib()
targets: Union[Sequence, Tensor] = attr.ib()
- transform: Optional[Callable] = attr.ib(default=None)
- target_transform: Optional[Callable] = attr.ib(default=None)
+ transform: Union[Optional[Callable], str] = attr.ib(default=None)
+ target_transform: Union[Optional[Callable], str] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
"""Pre init constructor."""
@@ -32,19 +34,31 @@ class BaseDataset(Dataset):
"""Post init constructor."""
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
+ self.transform = self._load_transform(self.transform)
+ self.target_transform = self._load_transform(self.target_transform)
+
+ @staticmethod
+ def _load_transform(
+ transform: Union[Optional[Callable], str]
+ ) -> Optional[Callable]:
+ if isinstance(transform, str):
+ return load_transform_from_file(transform)
+ return transform
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.data)
- def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
+ def __getitem__(
+ self, index: int
+ ) -> Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""Return a datum and its target, after processing by transforms.
Args:
index (int): Index of a datum in the dataset.
Returns:
- Tuple[Tensor, Tensor]: Datum and target pair.
+ Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]: Datum and target pair.
"""
datum, target = self.data[index], self.targets[index]
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 9ec6efe..e2bc5b9 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,7 +3,7 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Dict, List, Optional, Sequence, Set, Tuple
import zipfile
import attr
@@ -11,14 +11,14 @@ import h5py
from loguru import logger as log
import numpy as np
import toml
-import torchvision.transforms as T
from text_recognizer.data.base_data_module import (
BaseDataModule,
load_and_print_info,
)
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
-from text_recognizer.data.download_utils import download_dataset
+from text_recognizer.data.utils.download_utils import download_dataset
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
SEED = 4711
@@ -30,7 +30,9 @@ METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+ESSENTIALS_FILENAME = (
+ Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json"
+)
@attr.s(auto_attribs=True)
@@ -46,9 +48,6 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- train_fraction: float = attr.ib(default=0.8)
- transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
-
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self.dims = (1, *self.mapping.input_size)
@@ -226,4 +225,5 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
def download_emnist() -> None:
"""Download dataset from internet, if it does not exists, and displays info."""
- load_and_print_info(EMNIST)
+ transform = load_transform_from_file("transform/default.yaml")
+ load_and_print_info(EMNIST(transform=transform, test_transfrom=transform))
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 3ff8a54..1a64931 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,7 +1,7 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, List, Tuple
+from typing import DefaultDict, List, Tuple
import attr
import h5py
@@ -9,8 +9,7 @@ from loguru import logger as log
import numpy as np
import torch
from torch import Tensor
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
+import torchvision.transforms as T
from text_recognizer.data.base_data_module import (
BaseDataModule,
@@ -18,12 +17,13 @@ from text_recognizer.data.base_data_module import (
)
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
-from text_recognizer.data.sentence_generator import SentenceGenerator
+from text_recognizer.data.utils.sentence_generator import SentenceGenerator
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
ESSENTIALS_FILENAME = (
- Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
+ Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json"
)
SEED = 4711
@@ -37,7 +37,6 @@ MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
- augment: bool = attr.ib(default=True)
max_length: int = attr.ib(default=128)
min_overlap: float = attr.ib(default=0.0)
max_overlap: float = attr.ib(default=0.33)
@@ -98,21 +97,15 @@ class EMNISTLines(BaseDataModule):
x_val = f["x_val"][:]
y_val = torch.LongTensor(f["y_val"][:])
- self.data_train = BaseDataset(
- x_train, y_train, transform=_get_transform(augment=self.augment)
- )
- self.data_val = BaseDataset(
- x_val, y_val, transform=_get_transform(augment=self.augment)
- )
+ self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
+ self.data_val = BaseDataset(x_val, y_val, transform=self.transform)
if stage == "test" or stage is None:
with h5py.File(self.data_filename, "r") as f:
x_test = f["x_test"][:]
y_test = torch.LongTensor(f["y_test"][:])
- self.data_test = BaseDataset(
- x_test, y_test, transform=_get_transform(augment=False)
- )
+ self.data_test = BaseDataset(x_test, y_test, transform=self.test_transform)
def __repr__(self) -> str:
"""Return str about dataset."""
@@ -129,6 +122,7 @@ class EMNISTLines(BaseDataModule):
return basic
x, y = next(iter(self.train_dataloader()))
+ x = x[0] if isinstance(x, list) else x
data = (
"Train/val/test sizes: "
f"{len(self.data_train)}, "
@@ -184,7 +178,7 @@ class EMNISTLines(BaseDataModule):
def _get_samples_by_char(
samples: np.ndarray, labels: np.ndarray, mapping: List
-) -> defaultdict:
+) -> DefaultDict:
samples_by_char = defaultdict(list)
for sample, label in zip(samples, labels):
samples_by_char[mapping[label]].append(sample)
@@ -192,7 +186,7 @@ def _get_samples_by_char(
def _select_letter_samples_for_string(
- string: str, samples_by_char: defaultdict
+ string: str, samples_by_char: DefaultDict
) -> List[Tensor]:
null_image = torch.zeros((28, 28), dtype=torch.uint8)
sample_image_by_char = {}
@@ -207,7 +201,7 @@ def _select_letter_samples_for_string(
def _construct_image_from_string(
string: str,
- samples_by_char: defaultdict,
+ samples_by_char: DefaultDict,
min_overlap: float,
max_overlap: float,
width: int,
@@ -226,7 +220,7 @@ def _construct_image_from_string(
def _create_dataset_of_images(
num_samples: int,
- samples_by_char: defaultdict,
+ samples_by_char: DefaultDict,
sentence_generator: SentenceGenerator,
min_overlap: float,
max_overlap: float,
@@ -246,25 +240,8 @@ def _create_dataset_of_images(
return images, labels
-def _get_transform(augment: bool = False) -> Callable:
- if not augment:
- return transforms.Compose([transforms.ToTensor()])
- return transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.ColorJitter(brightness=(0.5, 1.0)),
- transforms.RandomAffine(
- degrees=3,
- translate=(0.0, 0.05),
- scale=(0.4, 1.1),
- shear=(-40, 50),
- interpolation=InterpolationMode.BILINEAR,
- fill=0,
- ),
- ]
- )
-
-
def generate_emnist_lines() -> None:
"""Generates a synthetic handwritten dataset and displays info."""
- load_and_print_info(EMNISTLines)
+ transform = load_transform_from_file("transform/emnist_lines.yaml")
+ test_transform = load_transform_from_file("test_transform/default.yaml")
+ load_and_print_info(EMNISTLines(transform=transform, test_transform=test_transform))
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 263bf8e..766f3e0 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -15,7 +15,7 @@ from loguru import logger as log
import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.download_utils import download_dataset
+from text_recognizer.data.utils.download_utils import download_dataset
RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam"
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 8b3a46c..87b8ef1 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,21 +1,17 @@
"""IAM original and sythetic dataset class."""
import attr
-from typing import Optional, Tuple
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
@attr.s(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
-
- augment: bool = attr.ib(default=True)
- train_fraction: float = attr.ib(default=0.8)
- word_pieces: bool = attr.ib(default=False)
- resize: Optional[Tuple[int, int]] = attr.ib(default=None)
+ """A dataset with synthetic and real handwritten paragraph."""
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
@@ -23,18 +19,18 @@ class IAMExtendedParagraphs(BaseDataModule):
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
- augment=self.augment,
- word_pieces=self.word_pieces,
- resize=self.resize,
+ transform=self.transform,
+ test_transform=self.test_transform,
+ target_transform=self.target_transform,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
- augment=self.augment,
- word_pieces=self.word_pieces,
- resize=self.resize,
+ transform=self.transform,
+ test_transform=self.test_transform,
+ target_transform=self.target_transform,
)
self.dims = self.iam_paragraphs.dims
@@ -69,6 +65,8 @@ class IAMExtendedParagraphs(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
xt, yt = next(iter(self.test_dataloader()))
+ x = x[0] if isinstance(x, list) else x
+ xt = xt[0] if isinstance(xt, list) else xt
data = (
f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
@@ -80,4 +78,9 @@ class IAMExtendedParagraphs(BaseDataModule):
def show_dataset_info() -> None:
- load_and_print_info(IAMExtendedParagraphs)
+ """Displays Iam extended dataset information."""
+ transform = load_transform_from_file("transform/paragraphs.yaml")
+ test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml")
+ load_and_print_info(
+ IAMExtendedParagraphs(transform=transform, test_transform=test_transform)
+ )
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 7a063c1..efd1cde 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -5,7 +5,6 @@ dataset.
"""
import json
from pathlib import Path
-import random
from typing import List, Sequence, Tuple
import attr
@@ -13,19 +12,17 @@ from loguru import logger as log
import numpy as np
from PIL import Image, ImageFile, ImageOps
from torch import Tensor
-import torchvision.transforms as T
-from torchvision.transforms.functional import InterpolationMode
-from text_recognizer.data import image_utils
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
split_dataset,
)
-from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.iam_paragraphs import get_target_transform
+from text_recognizer.data.mappings.emnist_mapping import EmnistMapping
+from text_recognizer.data.utils import image_utils
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -42,9 +39,6 @@ MAX_WORD_PIECE_LENGTH = 72
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- word_pieces: bool = attr.ib(default=False)
- augment: bool = attr.ib(default=True)
- train_fraction: float = attr.ib(default=0.8)
dims: Tuple[int, int, int] = attr.ib(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
@@ -94,10 +88,8 @@ class IAMLines(BaseDataModule):
data_train = BaseDataset(
x_train,
y_train,
- transform=get_transform(IMAGE_WIDTH, self.augment),
- target_transform=get_target_transform(
- self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH
- ),
+ transform=self.transform,
+ target_transform=self.target_transform,
)
self.data_train, self.data_val = split_dataset(
@@ -118,10 +110,8 @@ class IAMLines(BaseDataModule):
self.data_test = BaseDataset(
x_test,
y_test,
- transform=get_transform(IMAGE_WIDTH),
- target_transform=get_target_transform(
- self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH
- ),
+ transform=self.test_transform,
+ target_transform=self.target_transform,
)
if stage is None:
@@ -147,6 +137,8 @@ class IAMLines(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
xt, yt = next(iter(self.test_dataloader()))
+ x = x[0] if isinstance(x, list) else x
+ xt = xt[0] if isinstance(xt, list) else xt
data = (
"Train/val/test sizes: "
f"{len(self.data_train)}, "
@@ -217,51 +209,8 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
return crops, labels
-def get_transform(image_width: int, augment: bool = False) -> T.Compose:
- """Augment with brigthness, rotation, slant, translation, scale, and noise."""
-
- def embed_crop(
- crop: Image, augment: bool = augment, image_width: int = image_width
- ) -> Image:
- # Crop is PIL.Image of dtype="L" (so value range is [0, 255])
- image = Image.new("L", (image_width, IMAGE_HEIGHT))
-
- # Resize crop.
- crop_width, crop_height = crop.size
- new_crop_height = IMAGE_HEIGHT
- new_crop_width = int(new_crop_height * (crop_width / crop_height))
-
- if augment:
- # Add random stretching
- new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
- new_crop_width = min(new_crop_width, image_width)
- crop_resized = crop.resize(
- (new_crop_width, new_crop_height), resample=Image.BILINEAR
- )
-
- # Embed in image
- x = min(28, image_width - new_crop_width)
- y = IMAGE_HEIGHT - new_crop_height
- image.paste(crop_resized, (x, y))
-
- return image
-
- transfroms_list = [T.Lambda(embed_crop)]
-
- if augment:
- transfroms_list += [
- T.ColorJitter(brightness=(0.8, 1.6)),
- T.RandomAffine(
- degrees=1,
- shear=(-30, 20),
- interpolation=InterpolationMode.BILINEAR,
- fill=0,
- ),
- ]
- transfroms_list.append(T.ToTensor())
- return T.Compose(transfroms_list)
-
-
def generate_iam_lines() -> None:
"""Displays Iam Lines dataset statistics."""
- load_and_print_info(IAMLines)
+ transform = load_transform_from_file("transform/iam_lines.yaml")
+ test_transform = load_transform_from_file("test_transform/iam_lines.yaml")
+ load_and_print_info(IAMLines(transform=transform, test_transform=test_transform))
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 254c7f5..26674e0 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -8,7 +8,6 @@ from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
-from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -17,9 +16,9 @@ from text_recognizer.data.base_dataset import (
convert_strings_to_labels,
split_dataset,
)
-from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.transforms import WordPiece
+from text_recognizer.data.mappings.emnist_mapping import EmnistMapping
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
@@ -38,11 +37,6 @@ MAX_WORD_PIECE_LENGTH = 451
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- word_pieces: bool = attr.ib(default=False)
- augment: bool = attr.ib(default=True)
- train_fraction: float = attr.ib(default=0.8)
- resize: Optional[Tuple[int, int]] = attr.ib(default=None)
-
# Placeholders
dims: Tuple[int, int, int] = attr.ib(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
@@ -82,7 +76,7 @@ class IAMParagraphs(BaseDataModule):
"""Loads the data for training/testing."""
def _load_dataset(
- split: str, augment: bool, resize: Optional[Tuple[int, int]]
+ split: str, transform: T.Compose, target_transform: T.Compose
) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
@@ -92,12 +86,7 @@ class IAMParagraphs(BaseDataModule):
length=self.output_dims[0],
)
return BaseDataset(
- data,
- targets,
- transform=get_transform(
- image_shape=self.dims[1:], augment=augment, resize=resize
- ),
- target_transform=get_target_transform(self.word_pieces),
+ data, targets, transform=transform, target_transform=target_transform,
)
log.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -105,7 +94,9 @@ class IAMParagraphs(BaseDataModule):
if stage == "fit" or stage is None:
data_train = _load_dataset(
- split="train", augment=self.augment, resize=self.resize
+ split="train",
+ transform=self.transform,
+ target_transform=self.target_transform,
)
self.data_train, self.data_val = split_dataset(
dataset=data_train, fraction=self.train_fraction, seed=SEED
@@ -113,7 +104,9 @@ class IAMParagraphs(BaseDataModule):
if stage == "test" or stage is None:
self.data_test = _load_dataset(
- split="test", augment=False, resize=self.resize
+ split="test",
+ transform=self.test_transform,
+ target_transform=self.target_transform,
)
def __repr__(self) -> str:
@@ -130,6 +123,8 @@ class IAMParagraphs(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
xt, yt = next(iter(self.test_dataloader()))
+ x = x[0] if isinstance(x, list) else x
+ xt = xt[0] if isinstance(xt, list) else xt
data = (
"Train/val/test sizes: "
f"{len(self.data_train)}, "
@@ -274,39 +269,6 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def get_transform(
- image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]]
-) -> T.Compose:
- """Get transformations for images."""
- if augment:
- transforms_list = [
- T.RandomCrop(
- size=image_shape,
- padding=None,
- pad_if_needed=True,
- fill=0,
- padding_mode="constant",
- ),
- T.ColorJitter(brightness=(0.8, 1.6)),
- T.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
- ),
- ]
- else:
- transforms_list = [T.CenterCrop(image_shape)]
- if resize is not None:
- transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR))
- transforms_list.append(T.ToTensor())
- return T.Compose(transforms_list)
-
-
-def get_target_transform(
- word_pieces: bool, max_len: int = MAX_WORD_PIECE_LENGTH
-) -> Optional[T.Compose]:
- """Transform emnist characters to word pieces."""
- return T.Compose([WordPiece(max_len=max_len)]) if word_pieces else None
-
-
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
return PROCESSED_DATA_DIRNAME / split / "_labels.json"
@@ -324,4 +286,8 @@ def _num_lines(label: str) -> int:
def create_iam_paragraphs() -> None:
"""Loads and displays dataset statistics."""
- load_and_print_info(IAMParagraphs)
+ transform = load_transform_from_file("transform/paragraphs.yaml")
+ test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml")
+ load_and_print_info(
+ IAMParagraphs(transform=transform, test_transform=test_transform)
+ )
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index f253427..5718747 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -12,7 +12,6 @@ from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
)
-from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -21,13 +20,13 @@ from text_recognizer.data.iam_lines import (
)
from text_recognizer.data.iam_paragraphs import (
get_dataset_properties,
- get_target_transform,
- get_transform,
IAMParagraphs,
IMAGE_SCALE_FACTOR,
NEW_LINE_TOKEN,
resize_image,
)
+from text_recognizer.data.mappings.emnist_mapping import EmnistMapping
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
PROCESSED_DATA_DIRNAME = (
@@ -83,10 +82,8 @@ class IAMSyntheticParagraphs(IAMParagraphs):
self.data_train = BaseDataset(
data,
targets,
- transform=get_transform(
- image_shape=self.dims[1:], augment=self.augment, resize=self.resize
- ),
- target_transform=get_target_transform(self.word_pieces),
+ transform=self.transform,
+ target_transform=self.target_transforms,
)
def __repr__(self) -> str:
@@ -101,6 +98,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
return basic
x, y = next(iter(self.train_dataloader()))
+ x = x[0] if isinstance(x, list) else x
data = (
f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n"
f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
@@ -220,4 +218,8 @@ def generate_random_batches(
def create_synthetic_iam_paragraphs() -> None:
"""Creates and prints IAM Synthetic Paragraphs dataset."""
- load_and_print_info(IAMSyntheticParagraphs)
+ transform = load_transform_from_file("transform/paragraphs.yaml")
+ test_transform = load_transform_from_file("test_transform/paragraphs.yaml")
+ load_and_print_info(
+ IAMSyntheticParagraphs(transform=transform, test_transform=test_transform)
+ )