summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py18
1 files changed, 6 insertions, 12 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 0f3a2ce..6189f7d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
@@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings import WordPieceMapping
from text_recognizer.data.transforms import WordPiece
@@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- num_classes: int = attr.ib()
word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
@@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule):
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN])
def prepare_data(self) -> None:
"""Create data for training/testing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info(
+ log.info(
"Cropping IAM paragraph regions and saving them along with labels..."
)
- iam = IAM()
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
properties = {}
@@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
+ strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
)
return BaseDataset(
data,
@@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule):
target_transform=get_target_transform(self.word_pieces),
)
- logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
+ log.info(f"Loading IAM paragraph regions and lines for {stage}...")
_validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
if stage == "fit" or stage is None: