summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 402a8d4..f588587 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -93,14 +93,14 @@ class IAMParagraphs(BaseDataModule):
def _load_dataset(split: str, augment: bool) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
- data = [_resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
+ data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
)
return BaseDataset(
data,
targets,
- transform=_get_transform(image_shape=self.dims[1:], augment=augment),
+ transform=get_transform(image_shape=self.dims[1:], augment=augment),
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -142,7 +142,7 @@ class IAMParagraphs(BaseDataModule):
return basic + data
-def _get_dataset_properties() -> Dict:
+def get_dataset_properties() -> Dict:
"""Return properties describing the overall dataset."""
with (PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
properties = json.load(f)
@@ -173,7 +173,7 @@ def _validate_data_dims(
input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]]
) -> None:
"""Validates input and output dimensions against the properties of the dataset."""
- properties = _get_dataset_properties()
+ properties = get_dataset_properties()
max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR
if (
@@ -192,7 +192,7 @@ def _validate_data_dims(
)
-def _resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
+def resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
"""Resize image by scale factor."""
if scale_factor == 1:
return image
@@ -219,7 +219,7 @@ def _get_paragraph_crops_and_labels(
image = ImageOps.invert(image)
line_regions = iam.line_regions_by_id[id_]
- parameter_box = [
+ paragraph_box = [
min([region["x1"] for region in line_regions]),
min([region["y1"] for region in line_regions]),
max([region["x2"] for region in line_regions]),
@@ -227,7 +227,7 @@ def _get_paragraph_crops_and_labels(
]
lines = iam.line_strings_by_id[id_]
- crops[id_] = image.crop(parameter_box)
+ crops[id_] = image.crop(paragraph_box)
labels[id_] = NEW_LINE_TOKEN.join(lines)
if len(crops) != len(labels):
@@ -269,7 +269,7 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def _get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
+def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
"""Get transformations for images."""
if augment:
transforms_list = [