summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r-- text_recognizer/data/emnist_lines.py 49
1 files changed, 36 insertions, 13 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 43d55b9..062257d 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,9 +1,8 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import DefaultDict, List, Tuple
+from typing import Callable, DefaultDict, List, Optional, Tuple, Type
-from attrs import define, field
import h5py
from loguru import logger as log
import numpy as np
@@ -17,6 +16,7 @@ from text_recognizer.data.base_data_module import (
)
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
@@ -33,22 +33,45 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-@define(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
- max_length: int = field(default=128)
- min_overlap: float = field(default=0.0)
- max_overlap: float = field(default=0.33)
- num_train: int = field(default=10_000)
- num_val: int = field(default=2_000)
- num_test: int = field(default=2_000)
- emnist: EMNIST = field(init=False, default=None)
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ max_length: int = 128,
+ min_overlap: float = 0.0,
+ max_overlap: float = 0.33,
+ num_train: int = 10_000,
+ num_val: int = 2_000,
+ num_test: int = 2_000,
+ ) -> None:
+ super().__init__(
+ mapping,
+ transform,
+ test_transform,
+ target_transform,
+ train_fraction,
+ batch_size,
+ num_workers,
+ pin_memory,
+ )
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
- self.emnist = EMNIST(mapping=self.mapping)
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_train = num_train
+ self.num_val = num_val
+ self.num_test = num_test
+ self.emnist = EMNIST(mapping=self.mapping)
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+ IMAGE_X_PADDING