summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:58:51 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:58:51 +0200
commitdc3110e567f8ac3ad27048d3f346abac623658c0 (patch)
treed18e0d53ca757e5b86deb01db2ebfe352e5c804e
parent41c4d214f754295c5cf0b5b978e81b069336b574 (diff)
Add word piece mapping to IAM lines
-rw-r--r--text_recognizer/data/iam_lines.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index d456e64..0e45c68 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -24,6 +24,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.emnist_mapping import EmnistMapping
+from text_recognizer.data.iam_paragraphs import get_target_transform
from text_recognizer.data.iam import IAM
@@ -40,8 +41,9 @@ MAX_LABEL_LENGTH = 89
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
+ word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
- fraction: float = attr.ib(default=0.8)
+ train_fraction: float = attr.ib(default=0.8)
dims: Tuple[int, int, int] = attr.ib(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
@@ -89,11 +91,14 @@ class IAMLines(BaseDataModule):
labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
- x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
+ x_train,
+ y_train,
+ transform=get_transform(IMAGE_WIDTH, self.augment),
+ target_transform=get_target_transform(self.word_pieces),
)
self.data_train, self.data_val = split_dataset(
- dataset=data_train, fraction=self.fraction, seed=SEED
+ dataset=data_train, fraction=self.train_fraction, seed=SEED
)
if stage == "test" or stage is None:
@@ -108,7 +113,10 @@ class IAMLines(BaseDataModule):
labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
- x_test, y_test, transform=get_transform(IMAGE_WIDTH)
+ x_test,
+ y_test,
+ transform=get_transform(IMAGE_WIDTH),
+ target_transform=get_target_transform(self.word_pieces),
)
if stage is None: