summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_extended_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_extended_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py29
1 files changed, 16 insertions, 13 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 8b3a46c..87b8ef1 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,21 +1,17 @@
"""IAM original and sythetic dataset class."""
import attr
-from typing import Optional, Tuple
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
@attr.s(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
-
- augment: bool = attr.ib(default=True)
- train_fraction: float = attr.ib(default=0.8)
- word_pieces: bool = attr.ib(default=False)
- resize: Optional[Tuple[int, int]] = attr.ib(default=None)
+ """A dataset with synthetic and real handwritten paragraph."""
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
@@ -23,18 +19,18 @@ class IAMExtendedParagraphs(BaseDataModule):
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
- augment=self.augment,
- word_pieces=self.word_pieces,
- resize=self.resize,
+ transform=self.transform,
+ test_transform=self.test_transform,
+ target_transform=self.target_transform,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
- augment=self.augment,
- word_pieces=self.word_pieces,
- resize=self.resize,
+ transform=self.transform,
+ test_transform=self.test_transform,
+ target_transform=self.target_transform,
)
self.dims = self.iam_paragraphs.dims
@@ -69,6 +65,8 @@ class IAMExtendedParagraphs(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
xt, yt = next(iter(self.test_dataloader()))
+ x = x[0] if isinstance(x, list) else x
+ xt = xt[0] if isinstance(xt, list) else xt
data = (
f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
@@ -80,4 +78,9 @@ class IAMExtendedParagraphs(BaseDataModule):
def show_dataset_info() -> None:
- load_and_print_info(IAMExtendedParagraphs)
+ """Displays Iam extended dataset information."""
+ transform = load_transform_from_file("transform/paragraphs.yaml")
+ test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml")
+ load_and_print_info(
+ IAMExtendedParagraphs(transform=transform, test_transform=test_transform)
+ )