diff options
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index fe60e99..445b788 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,6 +3,7 @@ import json from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple +import attr from loguru import logger import numpy as np from PIL import Image, ImageOps @@ -33,33 +34,25 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = 682 +@attr.s(auto_attribs=True) class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - def __init__( - self, - batch_size: int = 16, - num_workers: int = 0, - train_fraction: float = 0.8, - augment: bool = True, - word_pieces: bool = False, - ) -> None: - super().__init__(batch_size, num_workers) - self.augment = augment - self.word_pieces = word_pieces + augment: bool = attr.ib(default=True) + train_fraction: float = attr.ib(default=0.8) + word_pieces: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: self.mapping, self.inverse_mapping, _ = emnist_mapping( extra_symbols=[NEW_LINE_TOKEN] ) - if word_pieces: + if self.word_pieces: self.mapping = WordPieceMapping() self.train_fraction = train_fraction self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) self.output_dims = (MAX_LABEL_LENGTH, 1) - self.data_train: BaseDataset = None - self.data_val: BaseDataset = None - self.data_test: BaseDataset = None def prepare_data(self) -> None: """Create data for training/testing.""" @@ -166,7 +159,10 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, + "crop_shape": { + "min": crop_shapes.min(axis=0), + "max": crop_shapes.max(axis=0), + }, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -287,7 +283,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose: ), T.ColorJitter(brightness=(0.8, 1.6)), T.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, + degrees=1, + shear=(-10, 10), + interpolation=InterpolationMode.BILINEAR, ), ] else: |