diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-25 23:32:50 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-25 23:32:50 +0200 |
commit | 9426cc794d8c28a65bbbf5ae5466a0a343078558 (patch) | |
tree | 44e31b0a7c58597d603ac29a693462aae4b6e9b0 /text_recognizer/data/iam_paragraphs.py | |
parent | 4e60c836fb710baceba570c28c06437db3ad5c9b (diff) |
Efficient net and non working transformer model.
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 62c44f9..24409bc 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -101,7 +101,7 @@ class IAMParagraphs(BaseDataModule): data, targets, transform=get_transform(image_shape=self.dims[1:], augment=augment), - target_transform=get_target_transform(self.word_pieces) + target_transform=get_target_transform(self.word_pieces), ) logger.info(f"Loading IAM paragraph regions and lines for {stage}...") @@ -162,10 +162,7 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": { - "min": crop_shapes.min(axis=0), - "max": crop_shapes.max(axis=0), - }, + "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -286,9 +283,7 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com ), transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( - degrees=1, - shear=(-10, 10), - interpolation=InterpolationMode.BILINEAR, + degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, ), ] else: @@ -296,10 +291,12 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com transforms_list.append(transforms.ToTensor()) return transforms.Compose(transforms_list) + def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: """Transform emnist characters to word pieces.""" return transforms.Compose([WordPiece()]) if word_pieces else None + def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" |