summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r--text_recognizer/data/iam_lines.py38
1 files changed, 19 insertions, 19 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index a55ff1c..3bb189c 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -22,16 +22,10 @@ from text_recognizer.data.iam import IAM
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils import image_utils
+from text_recognizer.metadata import iam_lines as metadata
ImageFile.LOAD_TRUNCATED_IMAGES = True
-SEED = 4711
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
-IMAGE_HEIGHT = 56
-IMAGE_WIDTH = 1024
-MAX_LABEL_LENGTH = 89
-MAX_WORD_PIECE_LENGTH = 72
-
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
@@ -57,12 +51,12 @@ class IAMLines(BaseDataModule):
num_workers,
pin_memory,
)
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (MAX_LABEL_LENGTH, 1)
+ self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)
+ self.output_dims = (metadata.MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
- if PROCESSED_DATA_DIRNAME.exists():
+ if metadata.PROCESSED_DATA_DIRNAME.exists():
return
log.info("Cropping IAM lines regions...")
@@ -76,24 +70,30 @@ class IAMLines(BaseDataModule):
log.info("Saving images, labels, and statistics...")
save_images_and_labels(
- crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
+ crops_train, labels_train, "train", metadata.PROCESSED_DATA_DIRNAME
+ )
+ save_images_and_labels(
+ crops_test, labels_test, "test", metadata.PROCESSED_DATA_DIRNAME
)
- save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)
- with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(
+ mode="w"
+ ) as f:
f.write(str(aspect_ratios.max()))
def setup(self, stage: str = None) -> None:
"""Load data for training/testing."""
- with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(
+ mode="r"
+ ) as f:
max_aspect_ratio = float(f.read())
- image_width = int(IMAGE_HEIGHT * max_aspect_ratio)
- if image_width >= IMAGE_WIDTH:
+ image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio)
+ if image_width >= metadata.IMAGE_WIDTH:
raise ValueError("image_width equal or greater than IMAGE_WIDTH")
if stage == "fit" or stage is None:
x_train, labels_train = load_line_crops_and_labels(
- "train", PROCESSED_DATA_DIRNAME
+ "train", metadata.PROCESSED_DATA_DIRNAME
)
if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2:
raise ValueError("Target length longer than max output length.")
@@ -109,12 +109,12 @@ class IAMLines(BaseDataModule):
)
self.data_train, self.data_val = split_dataset(
- dataset=data_train, fraction=self.train_fraction, seed=SEED
+ dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED
)
if stage == "test" or stage is None:
x_test, labels_test = load_line_crops_and_labels(
- "test", PROCESSED_DATA_DIRNAME
+ "test", metadata.PROCESSED_DATA_DIRNAME
)
if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2: