summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-23 21:55:42 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-23 21:55:42 +0100
commitae589fb3ffdbf6c4bb1ae35345f7a3665deeebc5 (patch)
tree1702f74c069679ebdd74a03892275c6eb3a80ffd /text_recognizer/datasets
parente3741de333a3a43a7968241b6eccaaac66dd7b20 (diff)
refactored emnist lines dataset
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r--text_recognizer/datasets/base_data_module.py2
-rw-r--r--text_recognizer/datasets/base_dataset.py8
-rw-r--r--text_recognizer/datasets/emnist.py12
-rw-r--r--text_recognizer/datasets/emnist_essentials.json2
-rw-r--r--text_recognizer/datasets/emnist_lines.py172
-rw-r--r--text_recognizer/datasets/sentence_generator.py30
6 files changed, 165 insertions, 61 deletions
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
index 830b39b..f5e7300 100644
--- a/text_recognizer/datasets/base_data_module.py
+++ b/text_recognizer/datasets/base_data_module.py
@@ -46,7 +46,7 @@ class BaseDataModule(pl.LightningDataModule):
def setup(self, stage: str = None) -> None:
"""Split into train, val, test, and set dims.
-
+
Should assign `torch Dataset` objects to self.data_train, self.data_val, and
optionally self.data_test.
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
index a004b8d..a9e9c24 100644
--- a/text_recognizer/datasets/base_dataset.py
+++ b/text_recognizer/datasets/base_dataset.py
@@ -61,13 +61,13 @@ def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int
) -> Tensor:
"""
- Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens,
- and padded wiht the <P> token.
+ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens,
+ and padded with the <p> token.
"""
- labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"]
+ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
for i, string in enumerate(strings):
tokens = list(string)
- tokens = ["<S>", *tokens, "</S>"]
+ tokens = ["<s>", *tokens, "</s>"]
for j, token in enumerate(tokens):
labels[i, j] = mapping[token]
return labels
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
index 7c208c4..66101b5 100644
--- a/text_recognizer/datasets/emnist.py
+++ b/text_recognizer/datasets/emnist.py
@@ -70,9 +70,11 @@ class EMNIST(BaseDataModule):
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
self.x_train = f["x_train"][:]
- self.y_train = f["y_train"][:]
+ self.y_train = f["y_train"][:].squeeze().astype(int)
- dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform)
+ dataset_train = BaseDataset(
+ self.x_train, self.y_train, transform=self.transform
+ )
train_size = int(self.train_fraction * len(dataset_train))
val_size = len(dataset_train) - train_size
self.data_train, self.data_val = random_split(
@@ -82,8 +84,10 @@ class EMNIST(BaseDataModule):
if stage == "test" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
self.x_test = f["x_test"][:]
- self.y_test = f["y_test"][:]
- self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)
+ self.y_test = f["y_test"][:].squeeze().astype(int)
+ self.data_test = BaseDataset(
+ self.x_test, self.y_test, transform=self.transform
+ )
def __repr__(self) -> str:
basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
index 100b36a..3f46a73 100644
--- a/text_recognizer/datasets/emnist_essentials.json
+++ b/text_recognizer/datasets/emnist_essentials.json
@@ -1 +1 @@
-{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} \ No newline at end of file
+{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py
index ae23feb..9ebad22 100644
--- a/text_recognizer/datasets/emnist_lines.py
+++ b/text_recognizer/datasets/emnist_lines.py
@@ -1,16 +1,21 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Dict, Sequence
+from typing import Callable, Dict, Tuple, Sequence
import h5py
from loguru import logger
import numpy as np
+from PIL import Image
import torch
from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
-from text_recognizer.datasets.base_dataset import BaseDataset
-from text_recognizer.datasets.base_data_module import BaseDataModule
+from text_recognizer.datasets.base_dataset import BaseDataset, convert_strings_to_labels
+from text_recognizer.datasets.base_data_module import (
+ BaseDataModule,
+ load_and_print_info,
+)
from text_recognizer.datasets.emnist import EMNIST
from text_recognizer.datasets.sentence_generator import SentenceGenerator
@@ -54,18 +59,23 @@ class EMNISTLines(BaseDataModule):
self.emnist = EMNIST()
self.mapping = self.emnist.mapping
- max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING
-
- if max_width <= IMAGE_WIDTH:
- raise ValueError("max_width greater than IMAGE_WIDTH")
+ max_width = (
+ int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+ + IMAGE_X_PADDING
+ )
+
+ if max_width >= IMAGE_WIDTH:
+ raise ValueError(
+ f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+ )
self.dims = (
self.emnist.dims[0],
- self.emnist.dims[1],
- self.emnist.dims[2] * self.max_length,
+ IMAGE_HEIGHT,
+ IMAGE_WIDTH
)
- if self.max_length <= MAX_OUTPUT_LENGTH:
+ if self.max_length >= MAX_OUTPUT_LENGTH:
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
self.output_dims = (MAX_OUTPUT_LENGTH, 1)
@@ -77,8 +87,11 @@ class EMNISTLines(BaseDataModule):
def data_filename(self) -> Path:
"""Return name of dataset."""
return (
- DATA_DIRNAME
- / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
+ DATA_DIRNAME / (f"ml_{self.max_length}_"
+ f"o{self.min_overlap:f}_{self.max_overlap:f}_"
+ f"ntr{self.num_train}_"
+ f"ntv{self.num_val}_"
+ f"nte{self.num_test}.h5")
)
def prepare_data(self) -> None:
@@ -92,21 +105,28 @@ class EMNISTLines(BaseDataModule):
def setup(self, stage: str = None) -> None:
logger.info("EMNISTLinesDataset loading data from HDF5...")
if stage == "fit" or stage is None:
+ print(self.data_filename)
with h5py.File(self.data_filename, "r") as f:
x_train = f["x_train"][:]
y_train = torch.LongTensor(f["y_train"][:])
x_val = f["x_val"][:]
y_val = torch.LongTensor(f["y_val"][:])
- self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment))
- self.data_val = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment))
+ self.data_train = BaseDataset(
+ x_train, y_train, transform=_get_transform(augment=self.augment)
+ )
+ self.data_val = BaseDataset(
+ x_val, y_val, transform=_get_transform(augment=self.augment)
+ )
if stage == "test" or stage is None:
with h5py.File(self.data_filename, "r") as f:
x_test = f["x_test"][:]
y_test = torch.LongTensor(f["y_test"][:])
- self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False))
+ self.data_test = BaseDataset(
+ x_test, y_test, transform=_get_transform(augment=False)
+ )
def __repr__(self) -> str:
"""Return str about dataset."""
@@ -132,53 +152,129 @@ class EMNISTLines(BaseDataModule):
def _generate_data(self, split: str) -> None:
logger.info(f"EMNISTLines generating data for {split}...")
- sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract by 2 because start/end token
+ sentence_generator = SentenceGenerator(
+ self.max_length - 2
+ ) # Subtract by 2 because start/end token
emnist = self.emnist
emnist.prepare_data()
emnist.setup()
if split == "train":
- samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+ samples_by_char = _get_samples_by_char(
+ emnist.x_train, emnist.y_train, emnist.mapping
+ )
num = self.num_train
elif split == "val":
- samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+ samples_by_char = _get_samples_by_char(
+ emnist.x_train, emnist.y_train, emnist.mapping
+ )
num = self.num_val
- elif split == "test":
- samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
+ else:
+ samples_by_char = _get_samples_by_char(
+ emnist.x_test, emnist.y_test, emnist.mapping
+ )
num = self.num_test
DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- with h5py.File(self.data_filename, "w") as f:
+ with h5py.File(self.data_filename, "a") as f:
x, y = _create_dataset_of_images(
- num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims
- )
- y = _convert_strings_to_labels(
- y,
- emnist.inverse_mapping,
- length=MAX_OUTPUT_LENGTH
- )
+ num,
+ samples_by_char,
+ sentence_generator,
+ self.min_overlap,
+ self.max_overlap,
+ self.dims,
+ )
+ y = convert_strings_to_labels(
+ y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ )
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
-def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict:
+
+def _get_samples_by_char(
+ samples: np.ndarray, labels: np.ndarray, mapping: Dict
+) -> defaultdict:
samples_by_char = defaultdict(list)
for sample, label in zip(samples, labels):
samples_by_char[mapping[label]].append(sample)
return samples_by_char
-def _construct_image_from_string():
- pass
-
-
def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
- pass
-
-
-def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
+ null_image = torch.zeros((28, 28), dtype=torch.uint8)
+ sample_image_by_char = {}
+ for char in string:
+ if char in sample_image_by_char:
+ continue
+ samples = samples_by_char[char]
+ sample = samples[np.random.choice(len(samples))] if samples else null_image
+ sample_image_by_char[char] = sample.reshape(28, 28)
+ return [sample_image_by_char[char] for char in string]
+
+
+def _construct_image_from_string(
+ string: str,
+ samples_by_char: defaultdict,
+ min_overlap: float,
+ max_overlap: float,
+ width: int,
+) -> torch.Tensor:
+ overlap = np.random.uniform(min_overlap, max_overlap)
+ sampled_images = _select_letter_samples_for_string(string, samples_by_char)
+ N = len(sampled_images)
+ H, W = sampled_images[0].shape
+ next_overlap_width = W - int(overlap * W)
+ concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
+ x = IMAGE_X_PADDING
+ for image in sampled_images:
+ concatenated_image[:, x : (x + W)] += image
+ x += next_overlap_width
+ return torch.minimum(torch.Tensor([255]), concatenated_image)
+
+
+def _create_dataset_of_images(
+ num_samples: int,
+ samples_by_char: defaultdict,
+ sentence_generator: SentenceGenerator,
+ min_overlap: float,
+ max_overlap: float,
+ dims: Tuple,
+) -> Tuple[torch.Tensor, torch.Tensor]:
images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
labels = []
for n in range(num_samples):
label = sentence_generator.generate()
- crop = _construct_image_from_string()
+ crop = _construct_image_from_string(
+ label, samples_by_char, min_overlap, max_overlap, dims[-1]
+ )
+ height = crop.shape[0]
+ y = (IMAGE_HEIGHT - height) // 2
+ images[n, y : (y + height), :] = crop
+ labels.append(label)
+ return images, labels
+
+
+def _get_transform(augment: bool = False) -> Callable:
+ if not augment:
+ return transforms.Compose([transforms.ToTensor()])
+ return transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.ColorJitter(brightness=(0.5, 1.0)),
+ transforms.RandomAffine(
+ degrees=3,
+ translate=(0.0, 0.05),
+ scale=(0.4, 1.1),
+ shear=(-40, 50),
+ interpolation=InterpolationMode.BILINEAR,
+ fill=0,
+ ),
+ ]
+ )
+
+
+def generate_emnist_lines() -> None:
+ """Generates a synthetic handwritten dataset and displays info."""
+ load_and_print_info(EMNISTLines)
diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py
index dd76652..53b781c 100644
--- a/text_recognizer/datasets/sentence_generator.py
+++ b/text_recognizer/datasets/sentence_generator.py
@@ -11,7 +11,7 @@ import numpy as np
from text_recognizer.datasets.util import DATA_DIRNAME
-NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
+NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk"
class SentenceGenerator:
@@ -47,18 +47,22 @@ class SentenceGenerator:
raise ValueError(
"Must provide max_length to this method or when making this object."
)
-
- index = np.random.randint(0, len(self.word_start_indices) - 1)
- start_index = self.word_start_indices[index]
- end_index_candidates = []
- for index in range(index + 1, len(self.word_start_indices)):
- if self.word_start_indices[index] - start_index > max_length:
- break
- end_index_candidates.append(self.word_start_indices[index])
- end_index = np.random.choice(end_index_candidates)
- sampled_text = self.corpus[start_index:end_index].strip()
- padding = "_" * (max_length - len(sampled_text))
- return sampled_text + padding
+
+ for _ in range(10):
+ try:
+ index = np.random.randint(0, len(self.word_start_indices) - 1)
+ start_index = self.word_start_indices[index]
+ end_index_candidates = []
+ for index in range(index + 1, len(self.word_start_indices)):
+ if self.word_start_indices[index] - start_index > max_length:
+ break
+ end_index_candidates.append(self.word_start_indices[index])
+ end_index = np.random.choice(end_index_candidates)
+ sampled_text = self.corpus[start_index:end_index].strip()
+ return sampled_text
+ except Exception:
+ pass
+ raise RuntimeError("Was not able to generate a valid string")
def brown_corpus() -> str: