diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:58:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:58:51 +0200 |
commit | dc3110e567f8ac3ad27048d3f346abac623658c0 (patch) | |
tree | d18e0d53ca757e5b86deb01db2ebfe352e5c804e | |
parent | 41c4d214f754295c5cf0b5b978e81b069336b574 (diff) |
Add word piece mapping to IAM lines
-rw-r--r-- | text_recognizer/data/iam_lines.py | 16 |
1 file changed, 12 insertions, 4 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index d456e64..0e45c68 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -24,6 +24,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.emnist_mapping import EmnistMapping +from text_recognizer.data.iam_paragraphs import get_target_transform from text_recognizer.data.iam import IAM @@ -40,8 +41,9 @@ MAX_LABEL_LENGTH = 89 class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" + word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) - fraction: float = attr.ib(default=0.8) + train_fraction: float = attr.ib(default=0.8) dims: Tuple[int, int, int] = attr.ib( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) @@ -89,11 +91,14 @@ class IAMLines(BaseDataModule): labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( - x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment) + x_train, + y_train, + transform=get_transform(IMAGE_WIDTH, self.augment), + target_transform=get_target_transform(self.word_pieces), ) self.data_train, self.data_val = split_dataset( - dataset=data_train, fraction=self.fraction, seed=SEED + dataset=data_train, fraction=self.train_fraction, seed=SEED ) if stage == "test" or stage is None: @@ -108,7 +113,10 @@ class IAMLines(BaseDataModule): labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( - x_test, y_test, transform=get_transform(IMAGE_WIDTH) + x_test, + y_test, + transform=get_transform(IMAGE_WIDTH), + target_transform=get_target_transform(self.word_pieces), ) if stage is None: |