author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-25 23:32:50 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-25 23:32:50 +0200
commit     9426cc794d8c28a65bbbf5ae5466a0a343078558 (patch)
tree       44e31b0a7c58597d603ac29a693462aae4b6e9b0 /text_recognizer/data
parent     4e60c836fb710baceba570c28c06437db3ad5c9b (diff)
EfficientNet and non-working transformer model.
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py   | 12
-rw-r--r--  text_recognizer/data/iam_paragraphs.py            | 13
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py  |  2
-rw-r--r--  text_recognizer/data/mappings.py                  | 12
-rw-r--r--  text_recognizer/data/transforms.py                |  4
5 files changed, 19 insertions(+), 24 deletions(-)
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 2380660..0a30a42 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -19,18 +19,10 @@ class IAMExtendedParagraphs(BaseDataModule):
super().__init__(batch_size, num_workers)
self.iam_paragraphs = IAMParagraphs(
- batch_size,
- num_workers,
- train_fraction,
- augment,
- word_pieces,
+ batch_size, num_workers, train_fraction, augment, word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size,
- num_workers,
- train_fraction,
- augment,
- word_pieces,
+ batch_size, num_workers, train_fraction, augment, word_pieces,
)
self.dims = self.iam_paragraphs.dims
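
The hunk above only reflows the two constructor calls onto single lines; behavior is unchanged. As a usage sketch, the data module might be constructed like this, assuming the five parameters shown in the diff are accepted as keywords (the concrete values are illustrative only):

    from text_recognizer.data.iam_extended_paragraphs import IAMExtendedParagraphs

    # Illustrative values; only the parameter names come from the diff above.
    datamodule = IAMExtendedParagraphs(
        batch_size=16,
        num_workers=4,
        train_fraction=0.8,
        augment=True,
        word_pieces=False,
    )
    # The module forwards the same five arguments to both IAMParagraphs and
    # IAMSyntheticParagraphs, then mirrors dims from the real paragraphs.
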
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 62c44f9..24409bc 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -101,7 +101,7 @@ class IAMParagraphs(BaseDataModule):
data,
targets,
transform=get_transform(image_shape=self.dims[1:], augment=augment),
- target_transform=get_target_transform(self.word_pieces)
+ target_transform=get_target_transform(self.word_pieces),
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -162,10 +162,7 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {
- "min": crop_shapes.min(axis=0),
- "max": crop_shapes.max(axis=0),
- },
+ "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -286,9 +283,7 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1,
- shear=(-10, 10),
- interpolation=InterpolationMode.BILINEAR,
+ degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
),
]
else:
@@ -296,10 +291,12 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
transforms_list.append(transforms.ToTensor())
return transforms.Compose(transforms_list)
+
def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
"""Transform emnist characters to word pieces."""
return transforms.Compose([WordPiece()]) if word_pieces else None
+
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
return PROCESSED_DATA_DIRNAME / split / "_labels.json"
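
The two added blank lines separate the module-level helpers per PEP 8, and the trailing comma on target_transform matches the other call sites. A minimal sketch of how the two helpers pair up, with an illustrative image shape:

    # Illustrative usage of the helpers above; (576, 640) is only an example
    # shape, and augment=True selects the ColorJitter/RandomAffine branch.
    transform = get_transform(image_shape=(576, 640), augment=True)
    # get_target_transform returns transforms.Compose([WordPiece()]) when
    # word_pieces is True, otherwise None, so it can be passed straight
    # through as an optional target_transform.
    target_transform = get_target_transform(word_pieces=True)
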
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 4ccc5c2..78e6c05 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -97,7 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
transform=get_transform(
image_shape=self.dims[1:], augment=self.augment
),
- target_transform=get_target_transform(self.word_pieces)
+ target_transform=get_target_transform(self.word_pieces),
)
def __repr__(self) -> str:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index f4016ba..190febe 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -58,13 +58,13 @@ class WordPieceMapping(EmnistMapping):
def __init__(
self,
num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n", ),
+ extra_symbols: Optional[Sequence[str]] = ("\n",),
) -> None:
super().__init__(extra_symbols)
self.wordpiece_processor = self._configure_wordpiece_processor(
@@ -90,7 +90,13 @@ class WordPieceMapping(EmnistMapping):
extra_symbols: Optional[Sequence[str]],
) -> Preprocessor:
data_dir = (
- (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb")
+ (
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
if data_dir is None
else Path(data_dir)
)
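
The reformatted expression resolves a default data directory relative to the module's own location. The same fallback pulled out as a standalone sketch (resolve_data_dir is a hypothetical helper name; the path components come from the diff):

    from pathlib import Path
    from typing import Optional, Union

    def resolve_data_dir(data_dir: Optional[Union[str, Path]] = None) -> Path:
        # Two directories above this file, then into the downloaded IAM
        # database, mirroring the fallback in WordPieceMapping above.
        if data_dir is None:
            return (
                Path(__file__).resolve().parents[2]
                / "data" / "downloaded" / "iam" / "iamdb"
            )
        return Path(data_dir)
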
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 8d1bedd..d0f1f35 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -13,7 +13,7 @@ class WordPiece:
def __init__(
self,
num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
@@ -35,4 +35,4 @@ class WordPiece:
self.max_len = max_len
def __call__(self, x: Tensor) -> Tensor:
- return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]
+ return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len]
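
The final hunk only adds PEP 8 slice spacing; the transform still maps EMNIST indices to word-piece indices and truncates the sequence to max_len. The truncation semantics as a minimal sketch, with illustrative values:

    import torch

    # Slicing a 1-D tensor with [:max_len] keeps at most max_len leading
    # elements, which is what WordPiece.__call__ does to the word-piece
    # index sequence.
    indices = torch.tensor([5, 12, 7, 3, 9])
    print(indices[:3])  # tensor([ 5, 12,  7])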