diff options
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 18 |
1 file changed, 6 insertions, 12 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 0f3a2ce..6189f7d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log import numpy as np from PIL import Image, ImageOps import torchvision.transforms as T @@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import WordPieceMapping from text_recognizer.data.transforms import WordPiece @@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682 class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - num_classes: int = attr.ib() word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) @@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule): init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) - inverse_mapping: Dict[str, int] = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN]) def prepare_data(self) -> None: """Create data for training/testing.""" if PROCESSED_DATA_DIRNAME.exists(): return - logger.info( + log.info( "Cropping IAM paragraph regions and saving them along with labels..." 
) - iam = IAM() + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() properties = {} @@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule): crops, labels = _load_processed_crops_and_labels(split) data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] targets = convert_strings_to_labels( - strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0] + strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0] ) return BaseDataset( data, @@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule): target_transform=get_target_transform(self.word_pieces), ) - logger.info(f"Loading IAM paragraph regions and lines for {stage}...") + log.info(f"Loading IAM paragraph regions and lines for {stage}...") _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims) if stage == "fit" or stage is None: |