summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py32
1 files changed, 15 insertions, 17 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index fe60e99..445b788 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,6 +3,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
+import attr
from loguru import logger
import numpy as np
from PIL import Image, ImageOps
@@ -33,33 +34,25 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
MAX_LABEL_LENGTH = 682
+@attr.s(auto_attribs=True)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.word_pieces = word_pieces
+ augment: bool = attr.ib(default=True)
+ train_fraction: float = attr.ib(default=0.8)
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping(
extra_symbols=[NEW_LINE_TOKEN]
)
- if word_pieces:
+ if self.word_pieces:
self.mapping = WordPieceMapping()
self.train_fraction = train_fraction
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (MAX_LABEL_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Create data for training/testing."""
@@ -166,7 +159,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -287,7 +283,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
),
T.ColorJitter(brightness=(0.8, 1.6)),
T.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else: