summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r--text_recognizer/data/iam_lines.py21
1 files changed, 8 insertions, 13 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index b7f3fdd..1c63729 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -2,15 +2,14 @@
If not created, will generate a handwritten lines dataset from the IAM paragraphs
dataset.
-
"""
import json
from pathlib import Path
import random
-from typing import Dict, List, Sequence, Tuple
+from typing import List, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
@@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils
@@ -48,17 +47,13 @@ class IAMLines(BaseDataModule):
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- def __attrs_post_init__(self) -> None:
- # TODO: refactor this
- self.mapping, self.inverse_mapping, _ = emnist_mapping()
-
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info("Cropping IAM lines regions...")
- iam = IAM()
+ log.info("Cropping IAM lines regions...")
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
crops_test, labels_test = line_crops_and_labels(iam, "test")
@@ -66,7 +61,7 @@ class IAMLines(BaseDataModule):
shapes = np.array([crop.size for crop in crops_train + crops_test])
aspect_ratios = shapes[:, 0] / shapes[:, 1]
- logger.info("Saving images, labels, and statistics...")
+ log.info("Saving images, labels, and statistics...")
save_images_and_labels(
crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
)
@@ -91,7 +86,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
- labels_train, self.inverse_mapping, length=self.output_dims[0]
+ labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
@@ -110,7 +105,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
- labels_test, self.inverse_mapping, length=self.output_dims[0]
+ labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
x_test, y_test, transform=get_transform(IMAGE_WIDTH)