summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-25 23:32:50 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-25 23:32:50 +0200
commit9426cc794d8c28a65bbbf5ae5466a0a343078558 (patch)
tree44e31b0a7c58597d603ac29a693462aae4b6e9b0 /text_recognizer/data/iam_paragraphs.py
parent4e60c836fb710baceba570c28c06437db3ad5c9b (diff)
Efficient net and non working transformer model.
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 62c44f9..24409bc 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -101,7 +101,7 @@ class IAMParagraphs(BaseDataModule):
data,
targets,
transform=get_transform(image_shape=self.dims[1:], augment=augment),
- target_transform=get_target_transform(self.word_pieces)
+ target_transform=get_target_transform(self.word_pieces),
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -162,10 +162,7 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {
- "min": crop_shapes.min(axis=0),
- "max": crop_shapes.max(axis=0),
- },
+ "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -286,9 +283,7 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1,
- shear=(-10, 10),
- interpolation=InterpolationMode.BILINEAR,
+ degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
),
]
else:
@@ -296,10 +291,12 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
transforms_list.append(transforms.ToTensor())
return transforms.Compose(transforms_list)
+
def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
"""Transform emnist characters to word pieces."""
return transforms.Compose([WordPiece()]) if word_pieces else None
+
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
return PROCESSED_DATA_DIRNAME / split / "_labels.json"