-rw-r--r--  text_recognizer/data/iam_paragraphs.py    | 19
-rw-r--r--  text_recognizer/data/iam_preprocessor.py  |  4
2 files changed, 14 insertions, 9 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 262533f..74b6165 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -11,12 +11,12 @@ import torchvision.transforms as T
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.base_dataset import (
     BaseDataset,
     convert_strings_to_labels,
     split_dataset,
 )
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.emnist_mapping import EmnistMapping
 from text_recognizer.data.iam import IAM
 from text_recognizer.data.transforms import WordPiece
@@ -55,7 +55,7 @@ class IAMParagraphs(BaseDataModule):
         log.info("Cropping IAM paragraph regions and saving them along with labels...")
-        iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
+        iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN}))
         iam.prepare_data()
         properties = {}
@@ -134,10 +134,14 @@ class IAMParagraphs(BaseDataModule):
             f"{len(self.data_train)}, "
             f"{len(self.data_val)}, "
             f"{len(self.data_test)}\n"
-            f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
-            f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
-            f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
-            f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+            "Train Batch x stats: "
+            f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+            "Train Batch y stats: "
+            f"{(y.shape, y.dtype, y.min(), y.max())}\n"
+            "Test Batch x stats: "
+            f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+            "Test Batch y stats: "
+            f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
         )
         return basic + data
@@ -161,7 +165,7 @@ def get_dataset_properties() -> Dict:
            "min": min(_get_property_values("num_lines")),
            "max": max(_get_property_values("num_lines")),
        },
-        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)},
        "aspect_ratio": {
            "min": aspect_ratio.min(axis=0),
            "max": aspect_ratio.max(axis=0),
@@ -316,4 +320,5 @@ def _num_lines(label: str) -> int:
 
 
 def create_iam_paragraphs() -> None:
+    """Loads and displays dataset statistics."""
     load_and_print_info(IAMParagraphs)
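The __repr__ hunk above relies on implicit concatenation of adjacent string literals: a long f-string can be split into a plain prefix plus an f-string suffix without changing the rendered text, which is how the over-long stats lines were brought under the line-length limit. Below is a minimal sketch of both fixes in this file, using made-up batch stats and an assumed value for NEW_LINE_TOKEN; it is illustrative, not the module's actual code.

# Sketch of the string-splitting pattern used in __repr__ above.
# The shape and dtype here are made-up stand-ins for a real batch.
x_shape, x_dtype = (16, 1, 576, 640), "torch.float32"

data = (
    "Train Batch x stats: "      # plain literal ...
    f"{(x_shape, x_dtype)}\n"    # ... and f-string are concatenated at compile time
)
assert data == f"Train Batch x stats: {(x_shape, x_dtype)}\n"

# The other fix is purely cosmetic: a trailing comma inside a one-element
# set literal is redundant, so {NEW_LINE_TOKEN,} == {NEW_LINE_TOKEN}.
NEW_LINE_TOKEN = "\n"  # assumed value, for illustration only
assert {NEW_LINE_TOKEN,} == {NEW_LINE_TOKEN}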
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index bcd77b4..700944e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -7,7 +7,7 @@ import collections
 import itertools
 from pathlib import Path
 import re
-from typing import List, Optional, Union, Set
+from typing import List, Optional, Set, Union
 
 import click
 from loguru import logger as log
@@ -140,7 +140,7 @@ class Preprocessor:
         if self.special_tokens is not None:
             pattern = f"({'|'.join(self.special_tokens)})"
             lines = list(filter(None, re.split(pattern, line)))
-            return torch.cat([self._to_index(l) for l in lines])
+            return torch.cat([self._to_index(line) for line in lines])
         return self._to_index(line)
 
     def to_text(self, indices: List[int]) -> str:
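The rename in the hunk above (l to line) sits inside a split-on-special-tokens pattern: re.split with a capturing group keeps the matched tokens in the output, filter(None, ...) drops the empty strings that appear around leading or trailing delimiters, and the indexed pieces are concatenated with torch.cat. The following standalone sketch shows that behaviour; the example tokens and the stand-in to_index mapping are assumptions, not the repository's actual configuration.

# Standalone sketch of the split-and-index pattern in Preprocessor.to_index.
# The special tokens and the integer mapping below are illustrative only.
import re

import torch

special_tokens = ["\n", "<eos>"]
pattern = f"({'|'.join(special_tokens)})"   # capturing group keeps the delimiters

text = "first line\nsecond<eos>"
pieces = list(filter(None, re.split(pattern, text)))
# -> ['first line', '\n', 'second', '<eos>']; filter(None, ...) removes the
#    empty strings re.split emits when the text starts or ends with a token.

def to_index(piece: str) -> torch.Tensor:
    """Stand-in for self._to_index: map each character to a fake integer id."""
    return torch.tensor([ord(c) for c in piece], dtype=torch.long)

indices = torch.cat([to_index(piece) for piece in pieces])
print(pieces, indices.shape)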