author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-27 00:08:04 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-27 00:08:04 +0200
commit    27ff7d113108e9cc51ddc5ff13b648b9c75fa865 (patch)
tree      96b35c2f65978b8718665aaded3d29f00aaf43e2 /text_recognizer/data
parent    3227735099f8acb37ffe658b8f04b6c308b64d23 (diff)
Add metadata
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/emnist.py                    48
-rw-r--r--  text_recognizer/data/emnist_lines.py              44
-rw-r--r--  text_recognizer/data/iam.py                       27
-rw-r--r--  text_recognizer/data/iam_lines.py                 38
-rw-r--r--  text_recognizer/data/iam_paragraphs.py            38
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py   8
6 files changed, 84 insertions(+), 119 deletions(-)
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 72cc80a..9c5727f 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -15,27 +15,15 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.download_utils import download_dataset
-
-SEED = 4711
-NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
-
-RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = (
- Path(__file__).parents[0].resolve() / "mappings" / "emnist_essentials.json"
-)
+from text_recognizer.metadata import emnist as metadata
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
'The EMNIST dataset is a set of handwritten character digits derived from the NIST
- Special Database 19 and converted to a 28x28 pixel image format and dataset structure
- that directly matches the MNIST dataset.'
+ Special Database 19 and converted to a 28x28 pixel image format and dataset
+ structure that directly matches the MNIST dataset.'
From https://www.nist.gov/itl/iad/image-group/emnist-dataset
The data split we will use is
@@ -48,13 +36,13 @@ class EMNIST(BaseDataModule):
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
- if not PROCESSED_DATA_FILENAME.exists():
+ if not metadata.PROCESSED_DATA_FILENAME.exists():
download_and_process_emnist()
def setup(self, stage: Optional[str] = None) -> None:
"""Loads the dataset specified by the stage."""
if stage == "fit" or stage is None:
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f:
self.x_train = f["x_train"][:]
self.y_train = f["y_train"][:].squeeze().astype(int)
@@ -62,11 +50,11 @@ class EMNIST(BaseDataModule):
self.x_train, self.y_train, transform=self.transform
)
self.data_train, self.data_val = split_dataset(
- dataset_train, fraction=self.train_fraction, seed=SEED
+ dataset_train, fraction=self.train_fraction, seed=metadata.SEED
)
if stage == "test" or stage is None:
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ with h5py.File(metadata.PROCESSED_DATA_FILENAME, "r") as f:
self.x_test = f["x_test"][:]
self.y_test = f["y_test"][:].squeeze().astype(int)
self.data_test = BaseDataset(
@@ -100,9 +88,9 @@ class EMNIST(BaseDataModule):
def download_and_process_emnist() -> None:
"""Downloads and preprocesses EMNIST dataset."""
- metadata = toml.load(METADATA_FILENAME)
- download_dataset(metadata, DL_DATA_DIRNAME)
- _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
+ metadata_ = toml.load(metadata.METADATA_FILENAME)
+ download_dataset(metadata_, metadata.DL_DATA_DIRNAME)
+ _process_raw_dataset(metadata_["filename"], metadata.DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path) -> None:
@@ -122,20 +110,22 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
.reshape(-1, 28, 28)
.swapaxes(1, 2)
)
- y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+ y_train = (
+ data["dataset"]["train"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS
+ )
x_test = (
data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
)
- y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+ y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + metadata.NUM_SPECIAL_TOKENS
- if SAMPLE_TO_BALANCE:
+ if metadata.SAMPLE_TO_BALANCE:
log.info("Balancing classes to reduce amount of data")
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
log.info("Saving to HDF5 in a compressed format...")
- PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
+ metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(metadata.PROCESSED_DATA_FILENAME, "w") as f:
f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
@@ -146,7 +136,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
characters = _augment_emnist_characters(mapping.values())
essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
- with ESSENTIALS_FILENAME.open(mode="w") as f:
+ with metadata.ESSENTIALS_FILENAME.open(mode="w") as f:
json.dump(essentials, f)
log.info("Cleaning up...")
@@ -156,7 +146,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Balances the dataset by taking the mean number of instances per class."""
- np.random.seed(SEED)
+ np.random.seed(metadata.SEED)
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
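Note: the new text_recognizer/metadata/emnist.py module is outside this diff (limited to text_recognizer/data), but the constants deleted above pin down what it must export. A minimal sketch, assuming the metadata.shared helper that the iam_synthetic_paragraphs hunk below imports; the ESSENTIALS_FILENAME path expression is likewise an assumption:

# Hypothetical reconstruction of text_recognizer/metadata/emnist.py.
# Values are copied verbatim from the constants deleted above.
from pathlib import Path

import text_recognizer.metadata.shared as shared

SEED = 4711
NUM_SPECIAL_TOKENS = 4
SAMPLE_TO_BALANCE = True

RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"

# The essentials file stays under text_recognizer/data/mappings/; rooting it
# relative to this metadata module is an assumption.
ESSENTIALS_FILENAME = (
    Path(__file__).resolve().parents[1] / "data" / "mappings" / "emnist_essentials.json"
)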
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index c36132e..63c9f22 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,12 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, DefaultDict, List, Optional, Tuple, Type
+from typing import Callable, DefaultDict, List, Optional, Tuple
import h5py
import numpy as np
import torch
-import torchvision.transforms as T
from loguru import logger as log
from torch import Tensor
@@ -16,17 +15,7 @@ from text_recognizer.data.emnist import EMNIST
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
-
-DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
-ESSENTIALS_FILENAME = (
- Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json"
-)
-
-SEED = 4711
-IMAGE_HEIGHT = 56
-IMAGE_WIDTH = 1024
-IMAGE_X_PADDING = 28
-MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+from text_recognizer.metadata import emnist_lines as metadata
class EMNISTLines(BaseDataModule):
@@ -70,25 +59,25 @@ class EMNISTLines(BaseDataModule):
self.emnist = EMNIST()
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
- + IMAGE_X_PADDING
+ + metadata.IMAGE_X_PADDING
)
- if max_width >= IMAGE_WIDTH:
+ if max_width >= metadata.IMAGE_WIDTH:
raise ValueError(
- f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+ f"max_width {max_width} greater than IMAGE_WIDTH {metadata.IMAGE_WIDTH}"
)
- self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)
+ self.dims = (self.emnist.dims[0], metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)
- if self.max_length >= MAX_OUTPUT_LENGTH:
+ if self.max_length >= metadata.MAX_OUTPUT_LENGTH:
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
- self.output_dims = (MAX_OUTPUT_LENGTH, 1)
+ self.output_dims = (metadata.MAX_OUTPUT_LENGTH, 1)
@property
def data_filename(self) -> Path:
"""Return name of dataset."""
- return DATA_DIRNAME / (
+ return metadata.DATA_DIRNAME / (
f"ml_{self.max_length}_"
f"o{self.min_overlap:f}_{self.max_overlap:f}_"
f"ntr{self.num_train}_"
@@ -100,7 +89,7 @@ class EMNISTLines(BaseDataModule):
"""Prepare the dataset."""
if self.data_filename.exists():
return
- np.random.seed(SEED)
+ np.random.seed(metadata.SEED)
self._generate_data("train")
self._generate_data("val")
self._generate_data("test")
@@ -146,7 +135,8 @@ class EMNISTLines(BaseDataModule):
f"{len(self.data_train)}, "
f"{len(self.data_val)}, "
f"{len(self.data_test)}\n"
- f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ "Batch x stats: "
+ f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
)
return basic + data
@@ -177,7 +167,7 @@ class EMNISTLines(BaseDataModule):
)
num = self.num_test
- DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(self.data_filename, "a") as f:
x, y = _create_dataset_of_images(
num,
@@ -188,7 +178,7 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ y, self.mapping.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
@@ -229,7 +219,7 @@ def _construct_image_from_string(
H, W = sampled_images[0].shape
next_overlap_width = W - int(overlap * W)
concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
- x = IMAGE_X_PADDING
+ x = metadata.IMAGE_X_PADDING
for image in sampled_images:
concatenated_image[:, x : (x + W)] += image
x += next_overlap_width
@@ -244,7 +234,7 @@ def _create_dataset_of_images(
max_overlap: float,
dims: Tuple,
) -> Tuple[Tensor, Tensor]:
- images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
+ images = torch.zeros((num_samples, metadata.IMAGE_HEIGHT, dims[2]))
labels = []
for n in range(num_samples):
label = sentence_generator.generate()
@@ -252,7 +242,7 @@ def _create_dataset_of_images(
label, samples_by_char, min_overlap, max_overlap, dims[-1]
)
height = crop.shape[0]
- y = (IMAGE_HEIGHT - height) // 2
+ y = (metadata.IMAGE_HEIGHT - height) // 2
images[n, y : (y + height), :] = crop
labels.append(label)
return images, labels
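Note: a sketch of the corresponding text_recognizer/metadata/emnist_lines.py, inferred from the deleted constants. The hunks above reference both metadata.DATA_DIRNAME (in data_filename) and metadata.PROCESSED_DATA_DIRNAME (in _generate_data), so one is assumed to alias the other; the deleted ESSENTIALS_FILENAME is never referenced through metadata in this diff and is omitted:

# Hypothetical reconstruction of text_recognizer/metadata/emnist_lines.py.
import text_recognizer.metadata.shared as shared

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
DATA_DIRNAME = PROCESSED_DATA_DIRNAME  # assumed alias; both names appear above

SEED = 4711
IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89  # Same as IAMLines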
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index e3baf88..c20b50b 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -15,14 +15,7 @@ from loguru import logger as log
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.utils.download_utils import download_dataset
-
-RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam"
-EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"
-
-DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
-LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
+from text_recognizer.metadata import iam as metadata
class IAM(BaseDataModule):
@@ -44,24 +37,24 @@ class IAM(BaseDataModule):
def __init__(self) -> None:
super().__init__()
- self.metadata: Dict = toml.load(METADATA_FILENAME)
+ self.metadata: Dict = toml.load(metadata.METADATA_FILENAME)
def prepare_data(self) -> None:
"""Prepares the IAM dataset."""
if self.xml_filenames:
return
- filename = download_dataset(self.metadata, DL_DATA_DIRNAME)
- _extract_raw_dataset(filename, DL_DATA_DIRNAME)
+ filename = download_dataset(self.metadata, metadata.DL_DATA_DIRNAME)
+ _extract_raw_dataset(filename, metadata.DL_DATA_DIRNAME)
@property
def xml_filenames(self) -> List[Path]:
"""Returns the xml filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+ return list((metadata.EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
@property
def form_filenames(self) -> List[Path]:
"""Returns the form filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+ return list((metadata.EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
@property
def form_filenames_by_id(self) -> Dict[str, Path]:
@@ -133,10 +126,10 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]:
x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
return {
- "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "x1": min(x1s) // metadata.DOWNSAMPLE_FACTOR - metadata.LINE_REGION_PADDING,
+ "y1": min(y1s) // metadata.DOWNSAMPLE_FACTOR - metadata.LINE_REGION_PADDING,
+ "x2": max(x2s) // metadata.DOWNSAMPLE_FACTOR + metadata.LINE_REGION_PADDING,
+ "y2": max(y2s) // metadata.DOWNSAMPLE_FACTOR + metadata.LINE_REGION_PADDING,
}
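Note: a sketch of text_recognizer/metadata/iam.py, mirroring the constants deleted above (only the shared.DATA_DIRNAME import is assumed):

# Hypothetical reconstruction of text_recognizer/metadata/iam.py.
import text_recognizer.metadata.shared as shared

RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam"
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"

DOWNSAMPLE_FACTOR = 2  # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16  # Add this many pixels around the exact coordinates.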
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index a55ff1c..3bb189c 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -22,16 +22,10 @@ from text_recognizer.data.iam import IAM
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils import image_utils
+from text_recognizer.metadata import iam_lines as metadata
ImageFile.LOAD_TRUNCATED_IMAGES = True
-SEED = 4711
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
-IMAGE_HEIGHT = 56
-IMAGE_WIDTH = 1024
-MAX_LABEL_LENGTH = 89
-MAX_WORD_PIECE_LENGTH = 72
-
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
@@ -57,12 +51,12 @@ class IAMLines(BaseDataModule):
num_workers,
pin_memory,
)
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (MAX_LABEL_LENGTH, 1)
+ self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)
+ self.output_dims = (metadata.MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
- if PROCESSED_DATA_DIRNAME.exists():
+ if metadata.PROCESSED_DATA_DIRNAME.exists():
return
log.info("Cropping IAM lines regions...")
@@ -76,24 +70,30 @@ class IAMLines(BaseDataModule):
log.info("Saving images, labels, and statistics...")
save_images_and_labels(
- crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
+ crops_train, labels_train, "train", metadata.PROCESSED_DATA_DIRNAME
+ )
+ save_images_and_labels(
+ crops_test, labels_test, "test", metadata.PROCESSED_DATA_DIRNAME
)
- save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)
- with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(
+ mode="w"
+ ) as f:
f.write(str(aspect_ratios.max()))
def setup(self, stage: str = None) -> None:
"""Load data for training/testing."""
- with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(
+ mode="r"
+ ) as f:
max_aspect_ratio = float(f.read())
- image_width = int(IMAGE_HEIGHT * max_aspect_ratio)
- if image_width >= IMAGE_WIDTH:
+ image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio)
+ if image_width >= metadata.IMAGE_WIDTH:
raise ValueError("image_width equal or greater than IMAGE_WIDTH")
if stage == "fit" or stage is None:
x_train, labels_train = load_line_crops_and_labels(
- "train", PROCESSED_DATA_DIRNAME
+ "train", metadata.PROCESSED_DATA_DIRNAME
)
if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2:
raise ValueError("Target length longer than max output length.")
@@ -109,12 +109,12 @@ class IAMLines(BaseDataModule):
)
self.data_train, self.data_val = split_dataset(
- dataset=data_train, fraction=self.train_fraction, seed=SEED
+ dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED
)
if stage == "test" or stage is None:
x_test, labels_test = load_line_crops_and_labels(
- "test", PROCESSED_DATA_DIRNAME
+ "test", metadata.PROCESSED_DATA_DIRNAME
)
if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2:
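Note: a sketch of text_recognizer/metadata/iam_lines.py from the constants deleted above. MAX_WORD_PIECE_LENGTH is never referenced through metadata in this diff; it is carried over on the assumption it moved with the rest:

# Hypothetical reconstruction of text_recognizer/metadata/iam_lines.py.
import text_recognizer.metadata.shared as shared

SEED = 4711
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines"
IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
MAX_LABEL_LENGTH = 89
MAX_WORD_PIECE_LENGTH = 72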
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index c7d5229..eec1b1f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -18,17 +18,7 @@ from text_recognizer.data.base_dataset import (
from text_recognizer.data.iam import IAM
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
-
-NEW_LINE_TOKEN = "\n"
-
-SEED = 4711
-IMAGE_SCALE_FACTOR = 2
-IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR
-IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
-MAX_LABEL_LENGTH = 682
-MAX_WORD_PIECE_LENGTH = 451
+from text_recognizer.metadata import iam_paragraphs as metadata
class IAMParagraphs(BaseDataModule):
@@ -55,17 +45,17 @@ class IAMParagraphs(BaseDataModule):
num_workers,
pin_memory,
)
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (MAX_LABEL_LENGTH, 1)
+ self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)
+ self.output_dims = (metadata.MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Create data for training/testing."""
- if PROCESSED_DATA_DIRNAME.exists():
+ if metadata.PROCESSED_DATA_DIRNAME.exists():
return
log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN}))
+ iam = IAM(mapping=EmnistMapping(extra_symbols={metadata.NEW_LINE_TOKEN}))
iam.prepare_data()
properties = {}
@@ -84,7 +74,7 @@ class IAMParagraphs(BaseDataModule):
}
)
- with (PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f:
json.dump(properties, f, indent=4)
def setup(self, stage: str = None) -> None:
@@ -94,7 +84,7 @@ class IAMParagraphs(BaseDataModule):
split: str, transform: T.Compose, target_transform: T.Compose
) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
- data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
+ data = [resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
strings=labels,
mapping=self.mapping.inverse_mapping,
@@ -117,7 +107,7 @@ class IAMParagraphs(BaseDataModule):
target_transform=self.target_transform,
)
self.data_train, self.data_val = split_dataset(
- dataset=data_train, fraction=self.train_fraction, seed=SEED
+ dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED
)
if stage == "test" or stage is None:
@@ -162,7 +152,7 @@ class IAMParagraphs(BaseDataModule):
def get_dataset_properties() -> Dict:
"""Return properties describing the overall dataset."""
- with (PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
+ with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
properties = json.load(f)
def _get_property_values(key: str) -> List:
@@ -193,7 +183,7 @@ def _validate_data_dims(
"""Validates input and output dimensions against the properties of the dataset."""
properties = get_dataset_properties()
- max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR
+ max_image_shape = properties["crop_shape"]["max"] / metadata.IMAGE_SCALE_FACTOR
if (
input_dims is not None
and input_dims[1] < max_image_shape[0]
@@ -246,7 +236,7 @@ def _get_paragraph_crops_and_labels(
lines = iam.line_strings_by_id[id_]
crops[id_] = image.crop(paragraph_box)
- labels[id_] = NEW_LINE_TOKEN.join(lines)
+ labels[id_] = metadata.NEW_LINE_TOKEN.join(lines)
if len(crops) != len(labels):
raise ValueError(f"Crops ({len(crops)}) does not match labels ({len(labels)})")
@@ -258,7 +248,7 @@ def _save_crops_and_labels(
crops: Dict[str, Image.Image], labels: Dict[str, str], split: str
) -> None:
"""Save crops, labels, and shapes of crops of a split."""
- (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True)
+ (metadata.PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True)
with _labels_filename(split).open("w") as f:
json.dump(labels, f, indent=4)
@@ -289,12 +279,12 @@ def _load_processed_crops_and_labels(
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
- return PROCESSED_DATA_DIRNAME / split / "_labels.json"
+ return metadata.PROCESSED_DATA_DIRNAME / split / "_labels.json"
def _crop_filename(id: str, split: str) -> Path:
"""Return filename of processed crop."""
- return PROCESSED_DATA_DIRNAME / split / f"{id}.png"
+ return metadata.PROCESSED_DATA_DIRNAME / split / f"{id}.png"
def _num_lines(label: str) -> int:
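Note: a sketch of text_recognizer/metadata/iam_paragraphs.py from the constants deleted above, including the derived image dimensions:

# Hypothetical reconstruction of text_recognizer/metadata/iam_paragraphs.py.
import text_recognizer.metadata.shared as shared

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs"

NEW_LINE_TOKEN = "\n"

SEED = 4711
IMAGE_SCALE_FACTOR = 2
IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR
IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
MAX_LABEL_LENGTH = 682
MAX_WORD_PIECE_LENGTH = 451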
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 5e66499..52ed398 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -6,7 +6,7 @@ import numpy as np
from loguru import logger as log
from PIL import Image
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.base_data_module import load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
@@ -23,9 +23,10 @@ from text_recognizer.data.iam_paragraphs import (
)
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
+from text_recognizer.metadata import shared as metadata
PROCESSED_DATA_DIRNAME = (
- BaseDataModule.data_dirname() / "processed" / "iam_synthetic_paragraphs"
+ metadata.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs"
)
@@ -117,7 +118,8 @@ class IAMSyntheticParagraphs(IAMParagraphs):
x = x[0] if isinstance(x, list) else x
data = (
f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n"
- f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ "Train Batch x stats: "
+ f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
)
return basic + data
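Note: the iam_synthetic_paragraphs hunk imports text_recognizer.metadata.shared and uses metadata.DATA_DIRNAME in place of BaseDataModule.data_dirname(). A minimal sketch of that shared module; rooting the data directory two levels above the package is an assumption, not shown in this diff:

# Hypothetical reconstruction of text_recognizer/metadata/shared.py. The only
# constraint visible here is that DATA_DIRNAME must equal the old
# BaseDataModule.data_dirname().
from pathlib import Path

DATA_DIRNAME = Path(__file__).resolve().parents[2] / "data"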