summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py38
1 files changed, 14 insertions, 24 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index c7d5229..eec1b1f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -18,17 +18,7 @@ from text_recognizer.data.base_dataset import (
from text_recognizer.data.iam import IAM
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
-
-NEW_LINE_TOKEN = "\n"
-
-SEED = 4711
-IMAGE_SCALE_FACTOR = 2
-IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR
-IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
-MAX_LABEL_LENGTH = 682
-MAX_WORD_PIECE_LENGTH = 451
+from text_recognizer.metadata import iam_paragraphs as metadata
class IAMParagraphs(BaseDataModule):
@@ -55,17 +45,17 @@ class IAMParagraphs(BaseDataModule):
num_workers,
pin_memory,
)
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (MAX_LABEL_LENGTH, 1)
+ self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)
+ self.output_dims = (metadata.MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Create data for training/testing."""
- if PROCESSED_DATA_DIRNAME.exists():
+ if metadata.PROCESSED_DATA_DIRNAME.exists():
return
log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN}))
+ iam = IAM(mapping=EmnistMapping(extra_symbols={metadata.NEW_LINE_TOKEN}))
iam.prepare_data()
properties = {}
@@ -84,7 +74,7 @@ class IAMParagraphs(BaseDataModule):
}
)
- with (PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f:
json.dump(properties, f, indent=4)
def setup(self, stage: str = None) -> None:
@@ -94,7 +84,7 @@ class IAMParagraphs(BaseDataModule):
split: str, transform: T.Compose, target_transform: T.Compose
) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
- data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
+ data = [resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
strings=labels,
mapping=self.mapping.inverse_mapping,
@@ -117,7 +107,7 @@ class IAMParagraphs(BaseDataModule):
target_transform=self.target_transform,
)
self.data_train, self.data_val = split_dataset(
- dataset=data_train, fraction=self.train_fraction, seed=SEED
+ dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED
)
if stage == "test" or stage is None:
@@ -162,7 +152,7 @@ class IAMParagraphs(BaseDataModule):
def get_dataset_properties() -> Dict:
"""Return properties describing the overall dataset."""
- with (PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
properties = json.load(f)
def _get_property_values(key: str) -> List:
@@ -193,7 +183,7 @@ def _validate_data_dims(
"""Validates input and output dimensions against the properties of the dataset."""
properties = get_dataset_properties()
- max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR
+ max_image_shape = properties["crop_shape"]["max"] / metadata.IMAGE_SCALE_FACTOR
if (
input_dims is not None
and input_dims[1] < max_image_shape[0]
@@ -246,7 +236,7 @@ def _get_paragraph_crops_and_labels(
lines = iam.line_strings_by_id[id_]
crops[id_] = image.crop(paragraph_box)
- labels[id_] = NEW_LINE_TOKEN.join(lines)
+ labels[id_] = metadata.NEW_LINE_TOKEN.join(lines)
if len(crops) != len(labels):
raise ValueError(f"Crops ({len(crops)}) does not match labels ({len(labels)})")
@@ -258,7 +248,7 @@ def _save_crops_and_labels(
crops: Dict[str, Image.Image], labels: Dict[str, str], split: str
) -> None:
"""Save crops, labels, and shapes of crops of a split."""
- (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True)
+ (metadata.PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True)
with _labels_filename(split).open("w") as f:
json.dump(labels, f, indent=4)
@@ -289,12 +279,12 @@ def _load_processed_crops_and_labels(
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
- return PROCESSED_DATA_DIRNAME / split / "_labels.json"
+ return metadata.PROCESSED_DATA_DIRNAME / split / "_labels.json"
def _crop_filename(id: str, split: str) -> Path:
"""Return filename of processed crop."""
- return PROCESSED_DATA_DIRNAME / split / f"{id}.png"
+ return metadata.PROCESSED_DATA_DIRNAME / split / f"{id}.png"
def _num_lines(label: str) -> int: