summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
commitbd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
treee55cb3744904f7c2a0348b100c7e92a65e538a16 /text_recognizer/data/iam_paragraphs.py
parent75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r--text_recognizer/data/iam_paragraphs.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 6189f7d..11f899f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -17,7 +17,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.transforms import WordPiece
@@ -50,11 +50,9 @@ class IAMParagraphs(BaseDataModule):
if PROCESSED_DATA_DIRNAME.exists():
return
- log.info(
- "Cropping IAM paragraph regions and saving them along with labels..."
- )
+ log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
iam.prepare_data()
properties = {}
@@ -83,7 +81,9 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
+ strings=labels,
+ mapping=self.mapping.inverse_mapping,
+ length=self.output_dims[0],
)
return BaseDataset(
data,