Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/iam_paragraphs.py    19
-rw-r--r--  text_recognizer/data/iam_preprocessor.py    4
2 files changed, 14 insertions, 9 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 262533f..74b6165 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -11,12 +11,12 @@ import torchvision.transforms as T
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.base_dataset import (
     BaseDataset,
     convert_strings_to_labels,
     split_dataset,
 )
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.emnist_mapping import EmnistMapping
 from text_recognizer.data.iam import IAM
 from text_recognizer.data.transforms import WordPiece
@@ -55,7 +55,7 @@ class IAMParagraphs(BaseDataModule):
 
         log.info("Cropping IAM paragraph regions and saving them along with labels...")
 
-        iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
+        iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN}))
         iam.prepare_data()
 
         properties = {}
@@ -134,10 +134,14 @@ class IAMParagraphs(BaseDataModule):
             f"{len(self.data_train)}, "
             f"{len(self.data_val)}, "
             f"{len(self.data_test)}\n"
-            f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
-            f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
-            f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
-            f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+            "Train Batch x stats: "
+            f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+            "Train Batch y stats: "
+            f"{(y.shape, y.dtype, y.min(), y.max())}\n"
+            "Test Batch x stats: "
+            f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+            "Test Batch y stats: "
+            f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
         )
         return basic + data
 
@@ -161,7 +165,7 @@ def get_dataset_properties() -> Dict:
             "min": min(_get_property_values("num_lines")),
             "max": max(_get_property_values("num_lines")),
         },
-        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)},
         "aspect_ratio": {
             "min": aspect_ratio.min(axis=0),
             "max": aspect_ratio.max(axis=0),
@@ -316,4 +320,5 @@ def _num_lines(label: str) -> int:
 
 
 def create_iam_paragraphs() -> None:
+    """Loads and displays dataset statistics."""
     load_and_print_info(IAMParagraphs)
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index bcd77b4..700944e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -7,7 +7,7 @@ import collections
 import itertools
 from pathlib import Path
 import re
-from typing import List, Optional, Union, Set
+from typing import List, Optional, Set, Union
 
 import click
 from loguru import logger as log
@@ -140,7 +140,7 @@ class Preprocessor:
         if self.special_tokens is not None:
             pattern = f"({'|'.join(self.special_tokens)})"
             lines = list(filter(None, re.split(pattern, line)))
-            return torch.cat([self._to_index(l) for l in lines])
+            return torch.cat([self._to_index(line) for line in lines])
         return self._to_index(line)
 
     def to_text(self, indices: List[int]) -> str:
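Note on the last hunk: the rename from `l` to `line` only avoids the ambiguous single-letter name; the splitting behavior is unchanged. As a rough standalone sketch of that behavior (the `special_tokens` values below are hypothetical and stand in for whatever the Preprocessor is configured with), `re.split` with a capturing group keeps the special tokens as their own list elements, and `filter(None, ...)` drops any empty strings the split produces:

    import re

    # Hypothetical special tokens; the real set comes from the Preprocessor configuration.
    special_tokens = {"<nl>", "<sp>"}
    pattern = f"({'|'.join(special_tokens)})"

    line = "move to<nl>the next line"
    parts = list(filter(None, re.split(pattern, line)))
    # The capturing group keeps the delimiters, so the special tokens survive the split,
    # e.g. ['move to', '<nl>', 'the next line']. In the repository code each piece is then
    # converted to indices and the resulting tensors are concatenated with torch.cat.
    print(parts)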