summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r--text_recognizer/data/iam_lines.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index d456e64..0e45c68 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -24,6 +24,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.emnist_mapping import EmnistMapping
+from text_recognizer.data.iam_paragraphs import get_target_transform
from text_recognizer.data.iam import IAM
@@ -40,8 +41,9 @@ MAX_LABEL_LENGTH = 89
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
+ word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
- fraction: float = attr.ib(default=0.8)
+ train_fraction: float = attr.ib(default=0.8)
dims: Tuple[int, int, int] = attr.ib(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
@@ -89,11 +91,14 @@ class IAMLines(BaseDataModule):
labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
- x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
+ x_train,
+ y_train,
+ transform=get_transform(IMAGE_WIDTH, self.augment),
+ target_transform=get_target_transform(self.word_pieces),
)
self.data_train, self.data_val = split_dataset(
- dataset=data_train, fraction=self.fraction, seed=SEED
+ dataset=data_train, fraction=self.train_fraction, seed=SEED
)
if stage == "test" or stage is None:
@@ -108,7 +113,10 @@ class IAMLines(BaseDataModule):
labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
- x_test, y_test, transform=get_transform(IMAGE_WIDTH)
+ x_test,
+ y_test,
+ transform=get_transform(IMAGE_WIDTH),
+ target_transform=get_target_transform(self.word_pieces),
)
if stage is None: