From 30e3ae483c846418b04ed48f014a4af2cf9a0771 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:03:11 +0200 Subject: Update transforms in datamodule/set --- text_recognizer/data/base_data_module.py | 8 ++- text_recognizer/data/base_dataset.py | 22 +++++-- text_recognizer/data/emnist.py | 16 ++--- text_recognizer/data/emnist_lines.py | 55 +++++------------ text_recognizer/data/iam.py | 2 +- text_recognizer/data/iam_extended_paragraphs.py | 29 +++++---- text_recognizer/data/iam_lines.py | 75 ++++-------------------- text_recognizer/data/iam_paragraphs.py | 68 ++++++--------------- text_recognizer/data/iam_synthetic_paragraphs.py | 18 +++--- 9 files changed, 104 insertions(+), 189 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index ee70176..3add837 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,13 +1,13 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Dict, Optional, Tuple, Type, TypeVar +from typing import Callable, Dict, Optional, Tuple, Type, TypeVar import attr from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from text_recognizer.data.base_dataset import BaseDataset -from text_recognizer.data.base_mapping import AbstractMapping +from text_recognizer.data.mappings.base_mapping import AbstractMapping T = TypeVar("T") @@ -29,6 +29,10 @@ class BaseDataModule(LightningDataModule): super().__init__() mapping: Type[AbstractMapping] = attr.ib() + transform: Optional[Callable] = attr.ib(default=None) + test_transform: Optional[Callable] = attr.ib(default=None) + target_transform: Optional[Callable] = attr.ib(default=None) + train_fraction: float = attr.ib(default=0.8) batch_size: int = attr.ib(default=16) num_workers: int = attr.ib(default=0) pin_memory: bool = attr.ib(default=True) diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index e08130d..b9567c7 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -6,6 +6,8 @@ import torch from torch import Tensor from torch.utils.data import Dataset +from text_recognizer.data.transforms.load_transform import load_transform_from_file + @attr.s class BaseDataset(Dataset): @@ -21,8 +23,8 @@ class BaseDataset(Dataset): data: Union[Sequence, Tensor] = attr.ib() targets: Union[Sequence, Tensor] = attr.ib() - transform: Optional[Callable] = attr.ib(default=None) - target_transform: Optional[Callable] = attr.ib(default=None) + transform: Union[Optional[Callable], str] = attr.ib(default=None) + target_transform: Union[Optional[Callable], str] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: """Pre init constructor.""" @@ -32,19 +34,31 @@ class BaseDataset(Dataset): """Post init constructor.""" if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") + self.transform = self._load_transform(self.transform) + self.target_transform = self._load_transform(self.target_transform) + + @staticmethod + def _load_transform( + transform: Union[Optional[Callable], str] + ) -> Optional[Callable]: + if isinstance(transform, str): + return load_transform_from_file(transform) + return transform def __len__(self) -> int: """Return the length of the dataset.""" return len(self.data) - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: + def __getitem__( + self, index: int + ) -> Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]: """Return a datum and its target, after processing by transforms. Args: index (int): Index of a datum in the dataset. Returns: - Tuple[Tensor, Tensor]: Datum and target pair. + Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]: Datum and target pair. """ datum, target = self.data[index], self.targets[index] diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 9ec6efe..e2bc5b9 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,7 +3,7 @@ import json import os from pathlib import Path import shutil -from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import Dict, List, Optional, Sequence, Set, Tuple import zipfile import attr @@ -11,14 +11,14 @@ import h5py from loguru import logger as log import numpy as np import toml -import torchvision.transforms as T from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, ) from text_recognizer.data.base_dataset import BaseDataset, split_dataset -from text_recognizer.data.download_utils import download_dataset +from text_recognizer.data.utils.download_utils import download_dataset +from text_recognizer.data.transforms.load_transform import load_transform_from_file SEED = 4711 @@ -30,7 +30,9 @@ METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json" +) @attr.s(auto_attribs=True) @@ -46,9 +48,6 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - train_fraction: float = attr.ib(default=0.8) - transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) - def __attrs_post_init__(self) -> None: """Post init configuration.""" self.dims = (1, *self.mapping.input_size) @@ -226,4 +225,5 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: def download_emnist() -> None: """Download dataset from internet, if it does not exists, and displays info.""" - load_and_print_info(EMNIST) + transform = load_transform_from_file("transform/default.yaml") + load_and_print_info(EMNIST(transform=transform, test_transfrom=transform)) diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 3ff8a54..1a64931 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,7 +1,7 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import Callable, List, Tuple +from typing import DefaultDict, List, Tuple import attr import h5py @@ -9,8 +9,7 @@ from loguru import logger as log import numpy as np import torch from torch import Tensor -from torchvision import transforms -from torchvision.transforms.functional import InterpolationMode +import torchvision.transforms as T from text_recognizer.data.base_data_module import ( BaseDataModule, @@ -18,12 +17,13 @@ from text_recognizer.data.base_data_module import ( ) from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST -from text_recognizer.data.sentence_generator import SentenceGenerator +from text_recognizer.data.utils.sentence_generator import SentenceGenerator +from text_recognizer.data.transforms.load_transform import load_transform_from_file DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" ESSENTIALS_FILENAME = ( - Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" + Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json" ) SEED = 4711 @@ -37,7 +37,6 @@ MAX_OUTPUT_LENGTH = 89 # Same as IAMLines class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" - augment: bool = attr.ib(default=True) max_length: int = attr.ib(default=128) min_overlap: float = attr.ib(default=0.0) max_overlap: float = attr.ib(default=0.33) @@ -98,21 +97,15 @@ class EMNISTLines(BaseDataModule): x_val = f["x_val"][:] y_val = torch.LongTensor(f["y_val"][:]) - self.data_train = BaseDataset( - x_train, y_train, transform=_get_transform(augment=self.augment) - ) - self.data_val = BaseDataset( - x_val, y_val, transform=_get_transform(augment=self.augment) - ) + self.data_train = BaseDataset(x_train, y_train, transform=self.transform) + self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = torch.LongTensor(f["y_test"][:]) - self.data_test = BaseDataset( - x_test, y_test, transform=_get_transform(augment=False) - ) + self.data_test = BaseDataset(x_test, y_test, transform=self.test_transform) def __repr__(self) -> str: """Return str about dataset.""" @@ -129,6 +122,7 @@ class EMNISTLines(BaseDataModule): return basic x, y = next(iter(self.train_dataloader())) + x = x[0] if isinstance(x, list) else x data = ( "Train/val/test sizes: " f"{len(self.data_train)}, " @@ -184,7 +178,7 @@ class EMNISTLines(BaseDataModule): def _get_samples_by_char( samples: np.ndarray, labels: np.ndarray, mapping: List -) -> defaultdict: +) -> DefaultDict: samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) @@ -192,7 +186,7 @@ def _get_samples_by_char( def _select_letter_samples_for_string( - string: str, samples_by_char: defaultdict + string: str, samples_by_char: DefaultDict ) -> List[Tensor]: null_image = torch.zeros((28, 28), dtype=torch.uint8) sample_image_by_char = {} @@ -207,7 +201,7 @@ def _select_letter_samples_for_string( def _construct_image_from_string( string: str, - samples_by_char: defaultdict, + samples_by_char: DefaultDict, min_overlap: float, max_overlap: float, width: int, @@ -226,7 +220,7 @@ def _construct_image_from_string( def _create_dataset_of_images( num_samples: int, - samples_by_char: defaultdict, + samples_by_char: DefaultDict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, @@ -246,25 +240,8 @@ def _create_dataset_of_images( return images, labels -def _get_transform(augment: bool = False) -> Callable: - if not augment: - return transforms.Compose([transforms.ToTensor()]) - return transforms.Compose( - [ - transforms.ToTensor(), - transforms.ColorJitter(brightness=(0.5, 1.0)), - transforms.RandomAffine( - degrees=3, - translate=(0.0, 0.05), - scale=(0.4, 1.1), - shear=(-40, 50), - interpolation=InterpolationMode.BILINEAR, - fill=0, - ), - ] - ) - - def generate_emnist_lines() -> None: """Generates a synthetic handwritten dataset and displays info.""" - load_and_print_info(EMNISTLines) + transform = load_transform_from_file("transform/emnist_lines.yaml") + test_transform = load_transform_from_file("test_transform/default.yaml") + load_and_print_info(EMNISTLines(transform=transform, test_transform=test_transform)) diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 263bf8e..766f3e0 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -15,7 +15,7 @@ from loguru import logger as log import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.download_utils import download_dataset +from text_recognizer.data.utils.download_utils import download_dataset RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam" diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 8b3a46c..87b8ef1 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,21 +1,17 @@ """IAM original and sythetic dataset class.""" import attr -from typing import Optional, Tuple from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs +from text_recognizer.data.transforms.load_transform import load_transform_from_file @attr.s(auto_attribs=True, repr=False) class IAMExtendedParagraphs(BaseDataModule): - - augment: bool = attr.ib(default=True) - train_fraction: float = attr.ib(default=0.8) - word_pieces: bool = attr.ib(default=False) - resize: Optional[Tuple[int, int]] = attr.ib(default=None) + """A dataset with synthetic and real handwritten paragraph.""" def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( @@ -23,18 +19,18 @@ class IAMExtendedParagraphs(BaseDataModule): batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, - augment=self.augment, - word_pieces=self.word_pieces, - resize=self.resize, + transform=self.transform, + test_transform=self.test_transform, + target_transform=self.target_transform, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( mapping=self.mapping, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, - augment=self.augment, - word_pieces=self.word_pieces, - resize=self.resize, + transform=self.transform, + test_transform=self.test_transform, + target_transform=self.target_transform, ) self.dims = self.iam_paragraphs.dims @@ -69,6 +65,8 @@ class IAMExtendedParagraphs(BaseDataModule): x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) + x = x[0] if isinstance(x, list) else x + xt = xt[0] if isinstance(xt, list) else xt data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" @@ -80,4 +78,9 @@ class IAMExtendedParagraphs(BaseDataModule): def show_dataset_info() -> None: - load_and_print_info(IAMExtendedParagraphs) + """Displays Iam extended dataset information.""" + transform = load_transform_from_file("transform/paragraphs.yaml") + test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml") + load_and_print_info( + IAMExtendedParagraphs(transform=transform, test_transform=test_transform) + ) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 7a063c1..efd1cde 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -5,7 +5,6 @@ dataset. """ import json from pathlib import Path -import random from typing import List, Sequence, Tuple import attr @@ -13,19 +12,17 @@ from loguru import logger as log import numpy as np from PIL import Image, ImageFile, ImageOps from torch import Tensor -import torchvision.transforms as T -from torchvision.transforms.functional import InterpolationMode -from text_recognizer.data import image_utils from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, split_dataset, ) -from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.iam_paragraphs import get_target_transform +from text_recognizer.data.mappings.emnist_mapping import EmnistMapping +from text_recognizer.data.utils import image_utils +from text_recognizer.data.transforms.load_transform import load_transform_from_file ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -42,9 +39,6 @@ MAX_WORD_PIECE_LENGTH = 72 class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - word_pieces: bool = attr.ib(default=False) - augment: bool = attr.ib(default=True) - train_fraction: float = attr.ib(default=0.8) dims: Tuple[int, int, int] = attr.ib( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) @@ -94,10 +88,8 @@ class IAMLines(BaseDataModule): data_train = BaseDataset( x_train, y_train, - transform=get_transform(IMAGE_WIDTH, self.augment), - target_transform=get_target_transform( - self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH - ), + transform=self.transform, + target_transform=self.target_transform, ) self.data_train, self.data_val = split_dataset( @@ -118,10 +110,8 @@ class IAMLines(BaseDataModule): self.data_test = BaseDataset( x_test, y_test, - transform=get_transform(IMAGE_WIDTH), - target_transform=get_target_transform( - self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH - ), + transform=self.test_transform, + target_transform=self.target_transform, ) if stage is None: @@ -147,6 +137,8 @@ class IAMLines(BaseDataModule): x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) + x = x[0] if isinstance(x, list) else x + xt = xt[0] if isinstance(xt, list) else xt data = ( "Train/val/test sizes: " f"{len(self.data_train)}, " @@ -217,51 +209,8 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li return crops, labels -def get_transform(image_width: int, augment: bool = False) -> T.Compose: - """Augment with brigthness, rotation, slant, translation, scale, and noise.""" - - def embed_crop( - crop: Image, augment: bool = augment, image_width: int = image_width - ) -> Image: - # Crop is PIL.Image of dtype="L" (so value range is [0, 255]) - image = Image.new("L", (image_width, IMAGE_HEIGHT)) - - # Resize crop. - crop_width, crop_height = crop.size - new_crop_height = IMAGE_HEIGHT - new_crop_width = int(new_crop_height * (crop_width / crop_height)) - - if augment: - # Add random stretching - new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) - new_crop_width = min(new_crop_width, image_width) - crop_resized = crop.resize( - (new_crop_width, new_crop_height), resample=Image.BILINEAR - ) - - # Embed in image - x = min(28, image_width - new_crop_width) - y = IMAGE_HEIGHT - new_crop_height - image.paste(crop_resized, (x, y)) - - return image - - transfroms_list = [T.Lambda(embed_crop)] - - if augment: - transfroms_list += [ - T.ColorJitter(brightness=(0.8, 1.6)), - T.RandomAffine( - degrees=1, - shear=(-30, 20), - interpolation=InterpolationMode.BILINEAR, - fill=0, - ), - ] - transfroms_list.append(T.ToTensor()) - return T.Compose(transfroms_list) - - def generate_iam_lines() -> None: """Displays Iam Lines dataset statistics.""" - load_and_print_info(IAMLines) + transform = load_transform_from_file("transform/iam_lines.yaml") + test_transform = load_transform_from_file("test_transform/iam_lines.yaml") + load_and_print_info(IAMLines(transform=transform, test_transform=test_transform)) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 254c7f5..26674e0 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -8,7 +8,6 @@ from loguru import logger as log import numpy as np from PIL import Image, ImageOps import torchvision.transforms as T -from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -17,9 +16,9 @@ from text_recognizer.data.base_dataset import ( convert_strings_to_labels, split_dataset, ) -from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.transforms import WordPiece +from text_recognizer.data.mappings.emnist_mapping import EmnistMapping +from text_recognizer.data.transforms.load_transform import load_transform_from_file PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs" @@ -38,11 +37,6 @@ MAX_WORD_PIECE_LENGTH = 451 class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - word_pieces: bool = attr.ib(default=False) - augment: bool = attr.ib(default=True) - train_fraction: float = attr.ib(default=0.8) - resize: Optional[Tuple[int, int]] = attr.ib(default=None) - # Placeholders dims: Tuple[int, int, int] = attr.ib( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) @@ -82,7 +76,7 @@ class IAMParagraphs(BaseDataModule): """Loads the data for training/testing.""" def _load_dataset( - split: str, augment: bool, resize: Optional[Tuple[int, int]] + split: str, transform: T.Compose, target_transform: T.Compose ) -> BaseDataset: crops, labels = _load_processed_crops_and_labels(split) data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] @@ -92,12 +86,7 @@ class IAMParagraphs(BaseDataModule): length=self.output_dims[0], ) return BaseDataset( - data, - targets, - transform=get_transform( - image_shape=self.dims[1:], augment=augment, resize=resize - ), - target_transform=get_target_transform(self.word_pieces), + data, targets, transform=transform, target_transform=target_transform, ) log.info(f"Loading IAM paragraph regions and lines for {stage}...") @@ -105,7 +94,9 @@ class IAMParagraphs(BaseDataModule): if stage == "fit" or stage is None: data_train = _load_dataset( - split="train", augment=self.augment, resize=self.resize + split="train", + transform=self.transform, + target_transform=self.target_transform, ) self.data_train, self.data_val = split_dataset( dataset=data_train, fraction=self.train_fraction, seed=SEED @@ -113,7 +104,9 @@ class IAMParagraphs(BaseDataModule): if stage == "test" or stage is None: self.data_test = _load_dataset( - split="test", augment=False, resize=self.resize + split="test", + transform=self.test_transform, + target_transform=self.target_transform, ) def __repr__(self) -> str: @@ -130,6 +123,8 @@ class IAMParagraphs(BaseDataModule): x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) + x = x[0] if isinstance(x, list) else x + xt = xt[0] if isinstance(xt, list) else xt data = ( "Train/val/test sizes: " f"{len(self.data_train)}, " @@ -274,39 +269,6 @@ def _load_processed_crops_and_labels( return ordered_crops, ordered_labels -def get_transform( - image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]] -) -> T.Compose: - """Get transformations for images.""" - if augment: - transforms_list = [ - T.RandomCrop( - size=image_shape, - padding=None, - pad_if_needed=True, - fill=0, - padding_mode="constant", - ), - T.ColorJitter(brightness=(0.8, 1.6)), - T.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, - ), - ] - else: - transforms_list = [T.CenterCrop(image_shape)] - if resize is not None: - transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR)) - transforms_list.append(T.ToTensor()) - return T.Compose(transforms_list) - - -def get_target_transform( - word_pieces: bool, max_len: int = MAX_WORD_PIECE_LENGTH -) -> Optional[T.Compose]: - """Transform emnist characters to word pieces.""" - return T.Compose([WordPiece(max_len=max_len)]) if word_pieces else None - - def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" @@ -324,4 +286,8 @@ def _num_lines(label: str) -> int: def create_iam_paragraphs() -> None: """Loads and displays dataset statistics.""" - load_and_print_info(IAMParagraphs) + transform = load_transform_from_file("transform/paragraphs.yaml") + test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml") + load_and_print_info( + IAMParagraphs(transform=transform, test_transform=test_transform) + ) diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index f253427..5718747 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -12,7 +12,6 @@ from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, ) -from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( line_crops_and_labels, @@ -21,13 +20,13 @@ from text_recognizer.data.iam_lines import ( ) from text_recognizer.data.iam_paragraphs import ( get_dataset_properties, - get_target_transform, - get_transform, IAMParagraphs, IMAGE_SCALE_FACTOR, NEW_LINE_TOKEN, resize_image, ) +from text_recognizer.data.mappings.emnist_mapping import EmnistMapping +from text_recognizer.data.transforms.load_transform import load_transform_from_file PROCESSED_DATA_DIRNAME = ( @@ -83,10 +82,8 @@ class IAMSyntheticParagraphs(IAMParagraphs): self.data_train = BaseDataset( data, targets, - transform=get_transform( - image_shape=self.dims[1:], augment=self.augment, resize=self.resize - ), - target_transform=get_target_transform(self.word_pieces), + transform=self.transform, + target_transform=self.target_transforms, ) def __repr__(self) -> str: @@ -101,6 +98,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): return basic x, y = next(iter(self.train_dataloader())) + x = x[0] if isinstance(x, list) else x data = ( f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" @@ -220,4 +218,8 @@ def generate_random_batches( def create_synthetic_iam_paragraphs() -> None: """Creates and prints IAM Synthetic Paragraphs dataset.""" - load_and_print_info(IAMSyntheticParagraphs) + transform = load_transform_from_file("transform/paragraphs.yaml") + test_transform = load_transform_from_file("test_transform/paragraphs.yaml") + load_and_print_info( + IAMSyntheticParagraphs(transform=transform, test_transform=test_transform) + ) -- cgit v1.2.3-70-g09d2