diff options
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/iam_lines.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index d456e64..0e45c68 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -24,6 +24,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.emnist_mapping import EmnistMapping +from text_recognizer.data.iam_paragraphs import get_target_transform from text_recognizer.data.iam import IAM @@ -40,8 +41,9 @@ MAX_LABEL_LENGTH = 89 class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" + word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) - fraction: float = attr.ib(default=0.8) + train_fraction: float = attr.ib(default=0.8) dims: Tuple[int, int, int] = attr.ib( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) @@ -89,11 +91,14 @@ class IAMLines(BaseDataModule): labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( - x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment) + x_train, + y_train, + transform=get_transform(IMAGE_WIDTH, self.augment), + target_transform=get_target_transform(self.word_pieces), ) self.data_train, self.data_val = split_dataset( - dataset=data_train, fraction=self.fraction, seed=SEED + dataset=data_train, fraction=self.train_fraction, seed=SEED ) if stage == "test" or stage is None: @@ -108,7 +113,10 @@ class IAMLines(BaseDataModule): labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( - x_test, y_test, transform=get_transform(IMAGE_WIDTH) + x_test, + y_test, + transform=get_transform(IMAGE_WIDTH), + target_transform=get_target_transform(self.word_pieces), ) if stage is None: |