From 27ff7d113108e9cc51ddc5ff13b648b9c75fa865 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 27 Sep 2022 00:08:04 +0200 Subject: Add metadata --- text_recognizer/data/emnist.py | 48 ++++++++++-------------- text_recognizer/data/emnist_lines.py | 44 +++++++++------------- text_recognizer/data/iam.py | 27 +++++-------- text_recognizer/data/iam_lines.py | 38 +++++++++---------- text_recognizer/data/iam_paragraphs.py | 38 +++++++------------ text_recognizer/data/iam_synthetic_paragraphs.py | 8 ++-- 6 files changed, 84 insertions(+), 119 deletions(-) (limited to 'text_recognizer/data') diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 72cc80a..9c5727f 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -15,27 +15,15 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.download_utils import download_dataset - -SEED = 4711 -NUM_SPECIAL_TOKENS = 4 -SAMPLE_TO_BALANCE = True - -RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" -PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = ( - Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json" -) +from text_recognizer.metadata import emnist as metadata class EMNIST(BaseDataModule): """Lightning DataModule class for loading EMNIST dataset. 'The EMNIST dataset is a set of handwritten character digits derived from the NIST - Special Database 19 and converted to a 28x28 pixel image format and dataset structure - that directly matches the MNIST dataset.' + Special Database 19 and converted to a 28x28 pixel image format and dataset + structure that directly matches the MNIST dataset.' From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is @@ -48,13 +36,13 @@ class EMNIST(BaseDataModule): def prepare_data(self) -> None: """Downloads dataset if not present.""" - if not PROCESSED_DATA_FILENAME.exists(): + if not metadata.PROCESSED_DATA_FILENAME.exists(): download_and_process_emnist() def setup(self, stage: Optional[str] = None) -> None: """Loads the dataset specified by the stage.""" if stage == "fit" or stage is None: - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f: self.x_train = f["x_train"][:] self.y_train = f["y_train"][:].squeeze().astype(int) @@ -62,11 +50,11 @@ class EMNIST(BaseDataModule): self.x_train, self.y_train, transform=self.transform ) self.data_train, self.data_val = split_dataset( - dataset_train, fraction=self.train_fraction, seed=SEED + dataset_train, fraction=self.train_fraction, seed=metadata.SEED ) if stage == "test" or stage is None: - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset( @@ -100,9 +88,9 @@ class EMNIST(BaseDataModule): def download_and_process_emnist() -> None: """Downloads and preprocesses EMNIST dataset.""" - metadata = toml.load(METADATA_FILENAME) - download_dataset(metadata, DL_DATA_DIRNAME) - _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) + metadata_ = toml.load(metadata.METADATA_FILENAME) + download_dataset(metadata_, metadata.DL_DATA_DIRNAME) + _process_raw_dataset(metadata_["filename"], metadata.DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path) -> None: @@ -122,20 +110,22 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: .reshape(-1, 28, 28) .swapaxes(1, 2) ) - y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + y_train = ( + data["dataset"]["train"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS + ) x_test = ( data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) ) - y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS - if SAMPLE_TO_BALANCE: + if metadata.SAMPLE_TO_BALANCE: log.info("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) log.info("Saving to HDF5 in a compressed format...") - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: + metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(metadata.PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") @@ -146,7 +136,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: characters = _augment_emnist_characters(mapping.values()) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} - with ESSENTIALS_FILENAME.open(mode="w") as f: + with metadata.ESSENTIALS_FILENAME.open(mode="w") as f: json.dump(essentials, f) log.info("Cleaning up...") @@ -156,7 +146,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Balances the dataset by taking the mean number of instances per class.""" - np.random.seed(SEED) + np.random.seed(metadata.SEED) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_indices = [] for label in np.unique(y.flatten()): diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index c36132e..63c9f22 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,12 +1,11 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import Callable, DefaultDict, List, Optional, Tuple, Type +from typing import Callable, DefaultDict, List, Optional, Tuple import h5py import numpy as np import torch -import torchvision.transforms as T from loguru import logger as log from torch import Tensor @@ -16,17 +15,7 @@ from text_recognizer.data.emnist import EMNIST from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator - -DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" -ESSENTIALS_FILENAME = ( - Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json" -) - -SEED = 4711 -IMAGE_HEIGHT = 56 -IMAGE_WIDTH = 1024 -IMAGE_X_PADDING = 28 -MAX_OUTPUT_LENGTH = 89 # Same as IAMLines +from text_recognizer.metadata import emnist_lines as metadata class EMNISTLines(BaseDataModule): @@ -70,25 +59,25 @@ class EMNISTLines(BaseDataModule): self.emnist = EMNIST() max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) - + IMAGE_X_PADDING + + metadata.IMAGE_X_PADDING ) - if max_width >= IMAGE_WIDTH: + if max_width >= metadata.IMAGE_WIDTH: raise ValueError( - f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" + f"max_width {max_width} greater than IMAGE_WIDTH {metadata.IMAGE_WIDTH}" ) - self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH) + self.dims = (self.emnist.dims[0], metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH) - if self.max_length >= MAX_OUTPUT_LENGTH: + if self.max_length >= metadata.MAX_OUTPUT_LENGTH: raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") - self.output_dims = (MAX_OUTPUT_LENGTH, 1) + self.output_dims = (metadata.MAX_OUTPUT_LENGTH, 1) @property def data_filename(self) -> Path: """Return name of dataset.""" - return DATA_DIRNAME / ( + return metadata.DATA_DIRNAME / ( f"ml_{self.max_length}_" f"o{self.min_overlap:f}_{self.max_overlap:f}_" f"ntr{self.num_train}_" @@ -100,7 +89,7 @@ class EMNISTLines(BaseDataModule): """Prepare the dataset.""" if self.data_filename.exists(): return - np.random.seed(SEED) + np.random.seed(metadata.SEED) self._generate_data("train") self._generate_data("val") self._generate_data("test") @@ -146,7 +135,8 @@ class EMNISTLines(BaseDataModule): f"{len(self.data_train)}, " f"{len(self.data_val)}, " f"{len(self.data_test)}\n" - f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + "Batch x stats: " + f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data @@ -177,7 +167,7 @@ class EMNISTLines(BaseDataModule): ) num = self.num_test - DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = _create_dataset_of_images( num, @@ -188,7 +178,7 @@ class EMNISTLines(BaseDataModule): self.dims, ) y = convert_strings_to_labels( - y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH + y, self.mapping.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") @@ -229,7 +219,7 @@ def _construct_image_from_string( H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) - x = IMAGE_X_PADDING + x = metadata.IMAGE_X_PADDING for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width @@ -244,7 +234,7 @@ def _create_dataset_of_images( max_overlap: float, dims: Tuple, ) -> Tuple[Tensor, Tensor]: - images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) + images = torch.zeros((num_samples, metadata.IMAGE_HEIGHT, dims[2])) labels = [] for n in range(num_samples): label = sentence_generator.generate() @@ -252,7 +242,7 @@ def _create_dataset_of_images( label, samples_by_char, min_overlap, max_overlap, dims[-1] ) height = crop.shape[0] - y = (IMAGE_HEIGHT - height) // 2 + y = (metadata.IMAGE_HEIGHT - height) // 2 images[n, y : (y + height), :] = crop labels.append(label) return images, labels diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index e3baf88..c20b50b 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -15,14 +15,7 @@ from loguru import logger as log from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.utils.download_utils import download_dataset - -RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam" -EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" - -DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. -LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. +from text_recognizer.metadata import iam as metadata class IAM(BaseDataModule): @@ -44,24 +37,24 @@ class IAM(BaseDataModule): def __init__(self) -> None: super().__init__() - self.metadata: Dict = toml.load(METADATA_FILENAME) + self.metadata: Dict = toml.load(metadata.METADATA_FILENAME) def prepare_data(self) -> None: """Prepares the IAM dataset.""" if self.xml_filenames: return - filename = download_dataset(self.metadata, DL_DATA_DIRNAME) - _extract_raw_dataset(filename, DL_DATA_DIRNAME) + filename = download_dataset(self.metadata, metadata.DL_DATA_DIRNAME) + _extract_raw_dataset(filename, metadata.DL_DATA_DIRNAME) @property def xml_filenames(self) -> List[Path]: """Returns the xml filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + return list((metadata.EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @property def form_filenames(self) -> List[Path]: """Returns the form filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + return list((metadata.EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def form_filenames_by_id(self) -> Dict[str, Path]: @@ -133,10 +126,10 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "x1": min(x1s) // metadata.DOWNSAMPLE_FACTOR - metadata.LINE_REGION_PADDING, + "y1": min(y1s) // metadata.DOWNSAMPLE_FACTOR - metadata.LINE_REGION_PADDING, + "x2": max(x2s) // metadata.DOWNSAMPLE_FACTOR + metadata.LINE_REGION_PADDING, + "y2": max(y2s) // metadata.DOWNSAMPLE_FACTOR + metadata.LINE_REGION_PADDING, } diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index a55ff1c..3bb189c 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -22,16 +22,10 @@ from text_recognizer.data.iam import IAM from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils import image_utils +from text_recognizer.metadata import iam_lines as metadata ImageFile.LOAD_TRUNCATED_IMAGES = True -SEED = 4711 -PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines" -IMAGE_HEIGHT = 56 -IMAGE_WIDTH = 1024 -MAX_LABEL_LENGTH = 89 -MAX_WORD_PIECE_LENGTH = 72 - class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" @@ -57,12 +51,12 @@ class IAMLines(BaseDataModule): num_workers, pin_memory, ) - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (MAX_LABEL_LENGTH, 1) + self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH) + self.output_dims = (metadata.MAX_LABEL_LENGTH, 1) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" - if PROCESSED_DATA_DIRNAME.exists(): + if metadata.PROCESSED_DATA_DIRNAME.exists(): return log.info("Cropping IAM lines regions...") @@ -76,24 +70,30 @@ class IAMLines(BaseDataModule): log.info("Saving images, labels, and statistics...") save_images_and_labels( - crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME + crops_train, labels_train, "train", metadata.PROCESSED_DATA_DIRNAME + ) + save_images_and_labels( + crops_test, labels_test, "test", metadata.PROCESSED_DATA_DIRNAME ) - save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) - with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f: + with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open( + mode="w" + ) as f: f.write(str(aspect_ratios.max())) def setup(self, stage: str = None) -> None: """Load data for training/testing.""" - with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f: + with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open( + mode="r" + ) as f: max_aspect_ratio = float(f.read()) - image_width = int(IMAGE_HEIGHT * max_aspect_ratio) - if image_width >= IMAGE_WIDTH: + image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio) + if image_width >= metadata.IMAGE_WIDTH: raise ValueError("image_width equal or greater than IMAGE_WIDTH") if stage == "fit" or stage is None: x_train, labels_train = load_line_crops_and_labels( - "train", PROCESSED_DATA_DIRNAME + "train", metadata.PROCESSED_DATA_DIRNAME ) if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2: raise ValueError("Target length longer than max output length.") @@ -109,12 +109,12 @@ class IAMLines(BaseDataModule): ) self.data_train, self.data_val = split_dataset( - dataset=data_train, fraction=self.train_fraction, seed=SEED + dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED ) if stage == "test" or stage is None: x_test, labels_test = load_line_crops_and_labels( - "test", PROCESSED_DATA_DIRNAME + "test", metadata.PROCESSED_DATA_DIRNAME ) if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2: diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index c7d5229..eec1b1f 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -18,17 +18,7 @@ from text_recognizer.data.base_dataset import ( from text_recognizer.data.iam import IAM from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file - -PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs" - -NEW_LINE_TOKEN = "\n" - -SEED = 4711 -IMAGE_SCALE_FACTOR = 2 -IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR -IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR -MAX_LABEL_LENGTH = 682 -MAX_WORD_PIECE_LENGTH = 451 +from text_recognizer.metadata import iam_paragraphs as metadata class IAMParagraphs(BaseDataModule): @@ -55,17 +45,17 @@ class IAMParagraphs(BaseDataModule): num_workers, pin_memory, ) - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (MAX_LABEL_LENGTH, 1) + self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH) + self.output_dims = (metadata.MAX_LABEL_LENGTH, 1) def prepare_data(self) -> None: """Create data for training/testing.""" - if PROCESSED_DATA_DIRNAME.exists(): + if metadata.PROCESSED_DATA_DIRNAME.exists(): return log.info("Cropping IAM paragraph regions and saving them along with labels...") - iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN})) + iam = IAM(mapping=EmnistMapping(extra_symbols={metadata.NEW_LINE_TOKEN})) iam.prepare_data() properties = {} @@ -84,7 +74,7 @@ class IAMParagraphs(BaseDataModule): } ) - with (PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f: + with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f: json.dump(properties, f, indent=4) def setup(self, stage: str = None) -> None: @@ -94,7 +84,7 @@ class IAMParagraphs(BaseDataModule): split: str, transform: T.Compose, target_transform: T.Compose ) -> BaseDataset: crops, labels = _load_processed_crops_and_labels(split) - data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] + data = [resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops] targets = convert_strings_to_labels( strings=labels, mapping=self.mapping.inverse_mapping, @@ -117,7 +107,7 @@ class IAMParagraphs(BaseDataModule): target_transform=self.target_transform, ) self.data_train, self.data_val = split_dataset( - dataset=data_train, fraction=self.train_fraction, seed=SEED + dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED ) if stage == "test" or stage is None: @@ -162,7 +152,7 @@ class IAMParagraphs(BaseDataModule): def get_dataset_properties() -> Dict: """Return properties describing the overall dataset.""" - with (PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f: + with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f: properties = json.load(f) def _get_property_values(key: str) -> List: @@ -193,7 +183,7 @@ def _validate_data_dims( """Validates input and output dimensions against the properties of the dataset.""" properties = get_dataset_properties() - max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR + max_image_shape = properties["crop_shape"]["max"] / metadata.IMAGE_SCALE_FACTOR if ( input_dims is not None and input_dims[1] < max_image_shape[0] @@ -246,7 +236,7 @@ def _get_paragraph_crops_and_labels( lines = iam.line_strings_by_id[id_] crops[id_] = image.crop(paragraph_box) - labels[id_] = NEW_LINE_TOKEN.join(lines) + labels[id_] = metadata.NEW_LINE_TOKEN.join(lines) if len(crops) != len(labels): raise ValueError(f"Crops ({len(crops)}) does not match labels ({len(labels)})") @@ -258,7 +248,7 @@ def _save_crops_and_labels( crops: Dict[str, Image.Image], labels: Dict[str, str], split: str ) -> None: """Save crops, labels, and shapes of crops of a split.""" - (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) + (metadata.PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) with _labels_filename(split).open("w") as f: json.dump(labels, f, indent=4) @@ -289,12 +279,12 @@ def _load_processed_crops_and_labels( def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" - return PROCESSED_DATA_DIRNAME / split / "_labels.json" + return metadata.PROCESSED_DATA_DIRNAME / split / "_labels.json" def _crop_filename(id: str, split: str) -> Path: """Return filename of processed crop.""" - return PROCESSED_DATA_DIRNAME / split / f"{id}.png" + return metadata.PROCESSED_DATA_DIRNAME / split / f"{id}.png" def _num_lines(label: str) -> int: diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 5e66499..52ed398 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -6,7 +6,7 @@ import numpy as np from loguru import logger as log from PIL import Image -from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info +from text_recognizer.data.base_data_module import load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( @@ -23,9 +23,10 @@ from text_recognizer.data.iam_paragraphs import ( ) from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file +from text_recognizer.metadata import shared as metadata PROCESSED_DATA_DIRNAME = ( - BaseDataModule.data_dirname() / "processed" / "iam_synthetic_paragraphs" + metadata.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs" ) @@ -117,7 +118,8 @@ class IAMSyntheticParagraphs(IAMParagraphs): x = x[0] if isinstance(x, list) else x data = ( f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n" - f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + "Train Batch x stats: " + f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data -- cgit v1.2.3-70-g09d2