summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r--text_recognizer/data/emnist_lines.py44
1 files changed, 17 insertions, 27 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index c36132e..63c9f22 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,12 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, DefaultDict, List, Optional, Tuple, Type
+from typing import Callable, DefaultDict, List, Optional, Tuple
import h5py
import numpy as np
import torch
-import torchvision.transforms as T
from loguru import logger as log
from torch import Tensor
@@ -16,17 +15,7 @@ from text_recognizer.data.emnist import EMNIST
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
-
-DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
-ESSENTIALS_FILENAME = (
- Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json"
-)
-
-SEED = 4711
-IMAGE_HEIGHT = 56
-IMAGE_WIDTH = 1024
-IMAGE_X_PADDING = 28
-MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+from text_recognizer.metadata import emnist_lines as metadata
class EMNISTLines(BaseDataModule):
@@ -70,25 +59,25 @@ class EMNISTLines(BaseDataModule):
self.emnist = EMNIST()
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
- + IMAGE_X_PADDING
+ + metadata.IMAGE_X_PADDING
)
- if max_width >= IMAGE_WIDTH:
+ if max_width >= metadata.IMAGE_WIDTH:
raise ValueError(
- f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+ f"max_width {max_width} greater than IMAGE_WIDTH {metadata.IMAGE_WIDTH}"
)
- self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)
+ self.dims = (self.emnist.dims[0], metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)
- if self.max_length >= MAX_OUTPUT_LENGTH:
+ if self.max_length >= metadata.MAX_OUTPUT_LENGTH:
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
- self.output_dims = (MAX_OUTPUT_LENGTH, 1)
+ self.output_dims = (metadata.MAX_OUTPUT_LENGTH, 1)
@property
def data_filename(self) -> Path:
"""Return name of dataset."""
- return DATA_DIRNAME / (
+ return metadata.DATA_DIRNAME / (
f"ml_{self.max_length}_"
f"o{self.min_overlap:f}_{self.max_overlap:f}_"
f"ntr{self.num_train}_"
@@ -100,7 +89,7 @@ class EMNISTLines(BaseDataModule):
"""Prepare the dataset."""
if self.data_filename.exists():
return
- np.random.seed(SEED)
+ np.random.seed(metadata.SEED)
self._generate_data("train")
self._generate_data("val")
self._generate_data("test")
@@ -146,7 +135,8 @@ class EMNISTLines(BaseDataModule):
f"{len(self.data_train)}, "
f"{len(self.data_val)}, "
f"{len(self.data_test)}\n"
- f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ "Batch x stats: "
+ f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
)
return basic + data
@@ -177,7 +167,7 @@ class EMNISTLines(BaseDataModule):
)
num = self.num_test
- DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(self.data_filename, "a") as f:
x, y = _create_dataset_of_images(
num,
@@ -188,7 +178,7 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ y, self.mapping.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
@@ -229,7 +219,7 @@ def _construct_image_from_string(
H, W = sampled_images[0].shape
next_overlap_width = W - int(overlap * W)
concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
- x = IMAGE_X_PADDING
+ x = metadata.IMAGE_X_PADDING
for image in sampled_images:
concatenated_image[:, x : (x + W)] += image
x += next_overlap_width
@@ -244,7 +234,7 @@ def _create_dataset_of_images(
max_overlap: float,
dims: Tuple,
) -> Tuple[Tensor, Tensor]:
- images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
+ images = torch.zeros((num_samples, metadata.IMAGE_HEIGHT, dims[2]))
labels = []
for n in range(num_samples):
label = sentence_generator.generate()
@@ -252,7 +242,7 @@ def _create_dataset_of_images(
label, samples_by_char, min_overlap, max_overlap, dims[-1]
)
height = crop.shape[0]
- y = (IMAGE_HEIGHT - height) // 2
+ y = (metadata.IMAGE_HEIGHT - height) // 2
images[n, y : (y + height), :] = crop
labels.append(label)
return images, labels