summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 262533f..74b6165 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -11,12 +11,12 @@ import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
split_dataset,
)
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.transforms import WordPiece
@@ -55,7 +55,7 @@ class IAMParagraphs(BaseDataModule):
log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN}))
iam.prepare_data()
properties = {}
@@ -134,10 +134,14 @@ class IAMParagraphs(BaseDataModule):
f"{len(self.data_train)}, "
f"{len(self.data_val)}, "
f"{len(self.data_test)}\n"
- f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
- f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
- f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
- f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+ "Train Batch x stats: "
+ f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ "Train Batch y stats: "
+ f"{(y.shape, y.dtype, y.min(), y.max())}\n"
+ "Test Batch x stats: "
+ f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+ "Test Batch y stats: "
+ f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
)
return basic + data
@@ -161,7 +165,7 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)},
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -316,4 +320,5 @@ def _num_lines(label: str) -> int:
def create_iam_paragraphs() -> None:
+ """Loads and displays dataset statistics."""
load_and_print_info(IAMParagraphs)