diff options
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r-- | text_recognizer/data/iam_lines.py | 38 |
1 file changed, 19 insertions, 19 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index a55ff1c..3bb189c 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -22,16 +22,10 @@ from text_recognizer.data.iam import IAM from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils import image_utils +from text_recognizer.metadata import iam_lines as metadata ImageFile.LOAD_TRUNCATED_IMAGES = True -SEED = 4711 -PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines" -IMAGE_HEIGHT = 56 -IMAGE_WIDTH = 1024 -MAX_LABEL_LENGTH = 89 -MAX_WORD_PIECE_LENGTH = 72 - class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" @@ -57,12 +51,12 @@ class IAMLines(BaseDataModule): num_workers, pin_memory, ) - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (MAX_LABEL_LENGTH, 1) + self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH) + self.output_dims = (metadata.MAX_LABEL_LENGTH, 1) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" - if PROCESSED_DATA_DIRNAME.exists(): + if metadata.PROCESSED_DATA_DIRNAME.exists(): return log.info("Cropping IAM lines regions...") @@ -76,24 +70,30 @@ class IAMLines(BaseDataModule): log.info("Saving images, labels, and statistics...") save_images_and_labels( - crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME + crops_train, labels_train, "train", metadata.PROCESSED_DATA_DIRNAME + ) + save_images_and_labels( + crops_test, labels_test, "test", metadata.PROCESSED_DATA_DIRNAME ) - save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) - with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f: + with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open( + mode="w" + ) as f: f.write(str(aspect_ratios.max())) def setup(self, stage: str = None) -> None: 
"""Load data for training/testing.""" - with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f: + with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open( + mode="r" + ) as f: max_aspect_ratio = float(f.read()) - image_width = int(IMAGE_HEIGHT * max_aspect_ratio) - if image_width >= IMAGE_WIDTH: + image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio) + if image_width >= metadata.IMAGE_WIDTH: raise ValueError("image_width equal or greater than IMAGE_WIDTH") if stage == "fit" or stage is None: x_train, labels_train = load_line_crops_and_labels( - "train", PROCESSED_DATA_DIRNAME + "train", metadata.PROCESSED_DATA_DIRNAME ) if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2: raise ValueError("Target length longer than max output length.") @@ -109,12 +109,12 @@ class IAMLines(BaseDataModule): ) self.data_train, self.data_val = split_dataset( - dataset=data_train, fraction=self.train_fraction, seed=SEED + dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED ) if stage == "test" or stage is None: x_test, labels_test = load_line_crops_and_labels( - "test", PROCESSED_DATA_DIRNAME + "test", metadata.PROCESSED_DATA_DIRNAME ) if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2: |