From 07dd14116fe1d8148fb614b160245287533620fc Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Mon, 3 Aug 2020 23:33:34 +0200
Subject: Working Emnist lines dataset.

---
 src/text_recognizer/character_predictor.py         |   5 +-
 src/text_recognizer/datasets/__init__.py           |  24 +-
 src/text_recognizer/datasets/emnist_dataset.py     | 279 +++++++++++++-----
 .../datasets/emnist_lines_dataset.py               | 326 +++++++++++++++++++++
 src/text_recognizer/datasets/sentence_generator.py |  81 +++++
 src/text_recognizer/datasets/util.py               |  11 +
 src/text_recognizer/models/base.py                 |  84 ++++--
 src/text_recognizer/models/character_model.py      |  32 +-
 .../tests/test_character_predictor.py              |  14 +-
 .../weights/CharacterModel_Emnist_LeNet_weights.pt | Bin 14483400 -> 14485305 bytes
 .../weights/CharacterModel_Emnist_MLP_weights.pt   | Bin 1702233 -> 1704096 bytes
 11 files changed, 724 insertions(+), 132 deletions(-)
 create mode 100644 src/text_recognizer/datasets/emnist_lines_dataset.py
 create mode 100644 src/text_recognizer/datasets/sentence_generator.py
 create mode 100644 src/text_recognizer/datasets/util.py

(limited to 'src/text_recognizer')

diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py
index a773f36..b733a53 100644
--- a/src/text_recognizer/character_predictor.py
+++ b/src/text_recognizer/character_predictor.py
@@ -11,10 +11,9 @@ from text_recognizer.util import read_image
 class CharacterPredictor:
     """Recognizes the character in handwritten character images."""
 
-    def __init__(self, network_fn: Type[nn.Module], network_args: Dict) -> None:
+    def __init__(self, network_fn: Type[nn.Module]) -> None:
         """Intializes the CharacterModel and load the pretrained weights."""
-        self.model = CharacterModel(network_fn=network_fn, network_args=network_args)
-        self.model.load_weights()
+        self.model = CharacterModel(network_fn=network_fn)
         self.model.eval()
 
     def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
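For reference, a minimal usage sketch of the new constructor: with no network_args, the underlying CharacterModel loads its pretrained weights on its own (see the models/base.py changes below). The image filename used here is hypothetical.

    from text_recognizer.character_predictor import CharacterPredictor
    from text_recognizer.networks import MLP

    predictor = CharacterPredictor(network_fn=MLP)
    prediction, confidence = predictor.predict("emnist_sample.png")  # hypothetical image path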
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 795be90..bfa6a02 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,4 +1,24 @@
 """Dataset modules."""
-from .emnist_dataset import EmnistDataLoader
+from .emnist_dataset import (
+    DATA_DIRNAME,
+    EmnistDataLoaders,
+    EmnistDataset,
+)
+from .emnist_lines_dataset import (
+    construct_image_from_string,
+    EmnistLinesDataset,
+    get_samples_by_character,
+)
+from .sentence_generator import SentenceGenerator
+from .util import Transpose
 
-__all__ = ["EmnistDataLoader"]
+__all__ = [
+    "construct_image_from_string",
+    "DATA_DIRNAME",
+    "EmnistDataset",
+    "EmnistDataLoaders",
+    "EmnistLinesDataset",
+    "get_samples_by_character",
+    "SentenceGenerator",
+    "Transpose",
+]
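The expanded public surface of the package can now be imported directly; a quick sketch of the exports listed above:

    from text_recognizer.datasets import (
        DATA_DIRNAME,
        EmnistDataset,
        EmnistDataLoaders,
        EmnistLinesDataset,
        SentenceGenerator,
        Transpose,
    )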
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index b92b57d..525df95 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -1,29 +1,23 @@
-"""Fetches a PyTorch DataLoader with the EMNIST dataset."""
+"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9)."""
 
 import json
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 from loguru import logger
 import numpy as np
 from PIL import Image
-from torch.utils.data import DataLoader
+import torch
+from torch.utils.data import DataLoader, Dataset
 from torchvision.datasets import EMNIST
-from torchvision.transforms import Compose, ToTensor
+from torchvision.transforms import Compose, Normalize, ToTensor
 
+from text_recognizer.datasets.util import Transpose
 
 DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
 ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
 
 
-class Transpose:
-    """Transposes the EMNIST image to the correct orientation."""
-
-    def __call__(self, image: Image) -> np.ndarray:
-        """Swaps axis."""
-        return np.array(image).swapaxes(0, 1)
-
-
 def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
     """Extract and saves EMNIST essentials."""
     labels = emnsit_dataset.classes
@@ -45,14 +39,187 @@ def download_emnist() -> None:
     save_emnist_essentials(dataset)
 
 
-def load_emnist_mapping() -> Dict[int, str]:
+def _load_emnist_essentials() -> Dict:
     """Load the EMNIST mapping."""
     with open(str(ESSENTIALS_FILENAME)) as f:
         essentials = json.load(f)
-    return dict(essentials["mapping"])
+    return essentials
+
+
+def _augment_emnist_mapping(mapping: Dict) -> Dict:
+    """Augment the mapping with extra symbols."""
+    # Extra symbols in IAM dataset
+    extra_symbols = [
+        " ",
+        "!",
+        '"',
+        "#",
+        "&",
+        "'",
+        "(",
+        ")",
+        "*",
+        "+",
+        ",",
+        "-",
+        ".",
+        "/",
+        ":",
+        ";",
+        "?",
+    ]
+
+    # padding symbol
+    extra_symbols.append("_")
+
+    max_key = max(mapping.keys())
+    extra_mapping = {}
+    for i, symbol in enumerate(extra_symbols):
+        extra_mapping[max_key + 1 + i] = symbol
+
+    return {**mapping, **extra_mapping}
+
+
+class EmnistDataset(Dataset):
+    """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
+
+    def __init__(
+        self,
+        train: bool = False,
+        sample_to_balance: bool = False,
+        subsample_fraction: Optional[float] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        seed: int = 4711,
+    ) -> None:
+        """Loads the dataset and the mappings.
+
+        Args:
+            train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to
+                False.
+            sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
+            subsample_fraction (float): The fraction of the dataset to keep. Defaults to None.
+            transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+            target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+            seed (int): Seed number. Defaults to 4711.
+
+        Raises:
+            ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
+        """
+
+        self.train = train
+        self.sample_to_balance = sample_to_balance
+        if subsample_fraction is not None:
+            if not 0.0 < subsample_fraction < 1.0:
+                raise ValueError("The subsample fraction must be in (0, 1).")
+        self.subsample_fraction = subsample_fraction
+        self.transform = transform
+        if self.transform is None:
+            self.transform = Compose([Transpose(), ToTensor()])
+
+        self.target_transform = target_transform
+        self.seed = seed
+
+        # Load dataset information.
+        essentials = _load_emnist_essentials()
+        self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
+        self.inverse_mapping = {v: k for k, v in self.mapping.items()}
+        self.num_classes = len(self.mapping)
+        self.input_shape = essentials["input_shape"]
+
+        # Placeholders
+        self.data = None
+        self.targets = None
+
+    def __len__(self) -> int:
+        """Returns the length of the dataset."""
+        return len(self.data)
+
+    def __getitem__(
+        self, index: Union[int, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Fetches samples from the dataset.
+
+        Args:
+            index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Data target tuple.
+
+        """
+        if torch.is_tensor(index):
+            index = index.tolist()
+
+        data = self.data[index]
+        targets = self.targets[index]
+
+        if self.transform:
+            data = self.transform(data)
+
+        if self.target_transform:
+            targets = self.target_transform(targets)
+
+        return data, targets
+
+    def __repr__(self) -> str:
+        """Returns information about the dataset."""
+        return (
+            "EMNIST Dataset\n"
+            f"Num classes: {self.num_classes}\n"
+            f"Mapping: {self.mapping}\n"
+            f"Input shape: {self.input_shape}\n"
+        )
+
+    def _sample_to_balance(self) -> None:
+        """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+        np.random.seed(self.seed)
+        x = self.data
+        y = self.targets
+        num_to_sample = int(np.bincount(y.flatten()).mean())
+        all_sampled_indices = []
+        for label in np.unique(y.flatten()):
+            inds = np.where(y == label)[0]
+            sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+            all_sampled_indices.append(sampled_indices)
+        indices = np.concatenate(all_sampled_indices)
+        x_sampled = x[indices]
+        y_sampled = y[indices]
+        self.data = x_sampled
+        self.targets = y_sampled
+
+    def _subsample(self) -> None:
+        """Subsamples the dataset to the specified fraction."""
+        x = self.data
+        y = self.targets
+        num_samples = int(x.shape[0] * self.subsample_fraction)
+        x_sampled = x[:num_samples]
+        y_sampled = y[:num_samples]
+        self.data = x_sampled
+        self.targets = y_sampled
+
+    def load_emnist_dataset(self) -> None:
+        """Fetch the EMNIST dataset."""
+        dataset = EMNIST(
+            root=DATA_DIRNAME,
+            split="byclass",
+            train=self.train,
+            download=False,
+            transform=None,
+            target_transform=None,
+        )
+
+        self.data = dataset.data
+        self.targets = dataset.targets
+
+        if self.sample_to_balance:
+            self._sample_to_balance()
+
+        if self.subsample_fraction is not None:
+            self._subsample()
 
 
-class EmnistDataLoader:
+class EmnistDataLoaders:
     """Class for Emnist DataLoaders."""
 
     def __init__(
@@ -68,7 +235,7 @@ class EmnistDataLoader:
         cuda: bool = True,
         seed: int = 4711,
     ) -> None:
-        """Fetches DataLoaders.
+        """Fetches DataLoaders for given split(s).
 
         Args:
             splits (List[str]): One or both of the dataset splits "train" and "val".
@@ -88,13 +255,17 @@ class EmnistDataLoader:
                 them. Defaults to True.
             seed (int): Seed for sampling.
 
+        Raises:
+            ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
         """
         self.splits = splits
         self.sample_to_balance = sample_to_balance
+
         if subsample_fraction is not None:
-            assert (
-                0.0 < subsample_fraction < 1.0
-            ), " The subsample fraction must be in (0, 1)."
+            if not 0.0 < subsample_fraction < 1.0:
+                raise ValueError("The subsample fraction must be in (0, 1).")
+
         self.subsample_fraction = subsample_fraction
         self.transform = transform
         self.target_transform = target_transform
@@ -105,6 +276,10 @@ class EmnistDataLoader:
         self.seed = seed
         self._data_loaders = self._fetch_emnist_data_loaders()
 
+    def __repr__(self) -> str:
+        """Returns information about the dataset."""
+        return self._data_loaders[self.splits[0]].dataset.__repr__()
+
     @property
     def __name__(self) -> str:
         """Returns the name of the dataset."""
@@ -128,59 +303,6 @@ class EmnistDataLoader:
         except KeyError:
             raise ValueError(f"Split {split} does not exist.")
 
-    def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST:
-        """Because the dataset is not balanced, we take at most the mean number of instances per class."""
-        np.random.seed(self.seed)
-        x = dataset.data
-        y = dataset.targets
-        num_to_sample = int(np.bincount(y.flatten()).mean())
-        all_sampled_indices = []
-        for label in np.unique(y.flatten()):
-            inds = np.where(y == label)[0]
-            sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
-            all_sampled_indices.append(sampled_indices)
-        indices = np.concatenate(all_sampled_indices)
-        x_sampled = x[indices]
-        y_sampled = y[indices]
-        dataset.data = x_sampled
-        dataset.targets = y_sampled
-
-        return dataset
-
-    def _subsample(self, dataset: type = EMNIST) -> EMNIST:
-        """Subsamples the dataset to the specified fraction."""
-        x = dataset.data
-        y = dataset.targets
-        num_samples = int(x.shape[0] * self.subsample_fraction)
-        x_sampled = x[:num_samples]
-        y_sampled = y[:num_samples]
-        dataset.data = x_sampled
-        dataset.targets = y_sampled
-
-        return dataset
-
-    def _fetch_emnist_dataset(self, train: bool) -> EMNIST:
-        """Fetch the EMNIST dataset."""
-        if self.transform is None:
-            transform = Compose([Transpose(), ToTensor()])
-
-        dataset = EMNIST(
-            root=DATA_DIRNAME,
-            split="byclass",
-            train=train,
-            download=False,
-            transform=transform,
-            target_transform=self.target_transform,
-        )
-
-        if self.sample_to_balance:
-            dataset = self._sample_to_balance(dataset)
-
-        if self.subsample_fraction is not None:
-            dataset = self._subsample(dataset)
-
-        return dataset
-
     def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]:
         """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
         data_loaders = {}
@@ -193,10 +315,19 @@ class EmnistDataLoader:
                 else:
                     train = False
 
-                dataset = self._fetch_emnist_dataset(train)
+                emnist_dataset = EmnistDataset(
+                    train=train,
+                    sample_to_balance=self.sample_to_balance,
+                    subsample_fraction=self.subsample_fraction,
+                    transform=self.transform,
+                    target_transform=self.target_transform,
+                    seed=self.seed,
+                )
+
+                emnist_dataset.load_emnist_dataset()
 
                 data_loader = DataLoader(
-                    dataset=dataset,
+                    dataset=emnist_dataset,
                     batch_size=self.batch_size,
                     shuffle=self.shuffle,
                     num_workers=self.num_workers,
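A minimal sketch of the reworked EMNIST API, assuming the EMNIST files are already downloaded to DATA_DIRNAME. Only arguments visible in this diff are passed; the split accessor mirrors the data_loaders("train") call made in models/base.py.

    from text_recognizer.datasets import EmnistDataset, EmnistDataLoaders

    # Standalone dataset: construct first, then fetch the underlying EMNIST data.
    emnist = EmnistDataset(train=True, sample_to_balance=True)
    emnist.load_emnist_dataset()
    image, target = emnist[0]

    # One DataLoader per split; only arguments shown in this diff are passed here.
    loaders = EmnistDataLoaders(splits=["train", "val"], sample_to_balance=True)
    train_loader = loaders("train")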
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
new file mode 100644
index 0000000..d49319f
--- /dev/null
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -0,0 +1,326 @@
+"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters."""
+
+from collections import defaultdict
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose, Normalize, ToTensor
+
+from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistDataset
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+from text_recognizer.datasets.util import Transpose
+
+DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
+ESSENTIALS_FILENAME = (
+    Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json"
+)
+
+
+class EmnistLinesDataset(Dataset):
+    """Synthetic dataset of lines from the Brown corpus with Emnist characters."""
+
+    def __init__(
+        self,
+        emnist: EmnistDataset,
+        train: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        max_length: int = 34,
+        min_overlap: float = 0,
+        max_overlap: float = 0.33,
+        num_samples: int = 10000,
+        seed: int = 4711,
+    ) -> None:
+        """Short summary.
+
+        Args:
+            emnist (EmnistDataset): An EmnistDataset object.
+            train (bool): Flag for the filename. Defaults to False.
+            transform (Optional[Callable]): The transform of the data. Defaults to None.
+            target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+            max_length (int): The maximum number of characters. Defaults to 34.
+            min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
+            max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
+            num_samples (int): Number of samples to generate. Defaults to 10000.
+            seed (int): Seed number. Defaults to 4711.
+
+        """
+        self.train = train
+        self.emnist = emnist
+
+        self.transform = transform
+        if self.transform is None:
+            self.transform = Compose([ToTensor()])
+
+        self.target_transform = target_transform
+        if self.target_transform is None:
+            self.target_transform = torch.tensor
+
+        self.mapping = self.emnist.mapping
+        self.num_classes = self.emnist.num_classes
+        self.max_length = max_length
+        self.min_overlap = min_overlap
+        self.max_overlap = max_overlap
+        self.num_samples = num_samples
+        self.input_shape = (
+            self.emnist.input_shape[0],
+            self.emnist.input_shape[1] * self.max_length,
+        )
+        self.output_shape = (self.max_length, self.num_classes)
+        self.seed = seed
+
+        # Placeholders for the generated dataset.
+        self.data = None
+        self.targets = None
+
+    def __len__(self) -> int:
+        """Returns the length of the dataset."""
+        return len(self.data)
+
+    def __getitem__(
+        self, index: Union[int, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Fetches data, target pair of the dataset for a given and index or indices.
+
+        Args:
+            index (Union[int, torch.Tensor]): An index or a tensor of indices.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Data target pair.
+
+        """
+        if torch.is_tensor(index):
+            index = index.tolist()
+
+        data = self.data[index]
+        targets = self.targets[index]
+
+        if self.transform:
+            data = self.transform(data)
+
+        if self.target_transform:
+            targets = self.target_transform(targets)
+
+        return data, targets
+
+    def __repr__(self) -> str:
+        """Returns information about the dataset."""
+        return (
+            "EMNIST Lines Dataset\n"  # pylint: disable=no-member
+            f"Max length: {self.max_length}\n"
+            f"Min overlap: {self.min_overlap}\n"
+            f"Max overlap: {self.max_overlap}\n"
+            f"Num classes: {self.num_classes}\n"
+            f"Input shape: {self.input_shape}\n"
+            f"Data: {self.data.shape}\n"
+            f"Tagets: {self.targets.shape}\n"
+        )
+
+    @property
+    def data_filename(self) -> Path:
+        """Path to the h5 file."""
+        filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
+        if self.train:
+            filename = "train_" + filename
+        else:
+            filename = "val_" + filename
+        return DATA_DIRNAME / filename
+
+    def _load_or_generate_data(self) -> None:
+        """Loads the dataset, if it does not exist a new dataset is generated before loading it."""
+        np.random.seed(self.seed)
+
+        if not self.data_filename.exists():
+            self._generate_data()
+        self._load_data()
+
+    def _load_data(self) -> None:
+        """Loads the dataset from the h5 file."""
+        logger.debug("EmnistLinesDataset loading data from HDF5...")
+        with h5py.File(self.data_filename, "r") as f:
+            self.data = f["data"][:]
+            self.targets = f["targets"][:]
+
+    def _generate_data(self) -> str:
+        """Generates a dataset with the Brown corpus and Emnist characters."""
+        logger.debug("Generating data...")
+
+        sentence_generator = SentenceGenerator(self.max_length)
+
+        # Load emnist dataset.
+        self.emnist.load_emnist_dataset()
+        samples_by_character = get_samples_by_character(
+            self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping,
+        )
+
+        DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+        with h5py.File(self.data_filename, "a") as f:
+            data, targets = create_dataset_of_images(
+                self.num_samples,
+                samples_by_character,
+                sentence_generator,
+                self.min_overlap,
+                self.max_overlap,
+            )
+
+            targets = convert_strings_to_categorical_labels(
+                targets, self.emnist.inverse_mapping
+            )
+
+            f.create_dataset("data", data=data, dtype="u1", compression="lzf")
+            f.create_dataset("targets", data=targets, dtype="u1", compression="lzf")
+
+
+def get_samples_by_character(
+    samples: np.ndarray, labels: np.ndarray, mapping: Dict
+) -> defaultdict:
+    """Creates a dictionary with character as key and value as the list of images of that character.
+
+    Args:
+        samples (np.ndarray): Dataset of images of characters.
+        labels (np.ndarray): The labels for each image.
+        mapping (Dict): The Emnist mapping dictionary.
+
+    Returns:
+        defaultdict: A dictionary with characters as keys and list of images as values.
+
+    """
+    samples_by_character = defaultdict(list)
+    for sample, label in zip(samples, labels.flatten()):
+        samples_by_character[mapping[label]].append(sample)
+    return samples_by_character
+
+
+def select_letter_samples_for_string(
+    string: str, samples_by_character: Dict
+) -> List[np.ndarray]:
+    """Randomly selects Emnist characters to use for the senetence.
+
+    Args:
+        string (str): The word or sentence.
+        samples_by_character (Dict): The dictionary of emnist images of each character.
+
+    Returns:
+        List[np.ndarray]: A list of emnist images of the string.
+
+    """
+    zero_image = np.zeros((28, 28), np.uint8)
+    sample_image_by_character = {}
+    for character in string:
+        if character in sample_image_by_character:
+            continue
+        samples = samples_by_character[character]
+        sample = samples[np.random.choice(len(samples))] if samples else zero_image
+        sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1)
+    return [sample_image_by_character[character] for character in string]
+
+
+def construct_image_from_string(
+    string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float
+) -> np.ndarray:
+    """Concatenates images of the characters in the string.
+
+    The concatenation uses a randomly selected overlap so that adjacent characters partially overlap.
+
+    Args:
+        string (str): The word or sentence.
+        samples_by_character (Dict): The dictionary of emnist images of each character.
+        min_overlap (float): Minimum amount of overlap between Emnist images.
+        max_overlap (float): Maximum amount of overlap between Emnist images.
+
+    Returns:
+        np.ndarray: The Emnist image of the string.
+
+    """
+    overlap = np.random.uniform(min_overlap, max_overlap)
+    sampled_images = select_letter_samples_for_string(string, samples_by_character)
+    length = len(sampled_images)
+    height, width = sampled_images[0].shape
+    next_overlap_width = width - int(overlap * width)
+    concatenated_image = np.zeros((height, width * length), np.uint8)
+    x = 0
+    for image in sampled_images:
+        concatenated_image[:, x : (x + width)] += image
+        x += next_overlap_width
+    return np.minimum(255, concatenated_image)
+
+
+def create_dataset_of_images(
+    length: int,
+    samples_by_character: Dict,
+    sentence_generator: SentenceGenerator,
+    min_overlap: float,
+    max_overlap: float,
+) -> Tuple[np.ndarray, List[str]]:
+    """Creates a dataset with images and labels from strings generated from the SentenceGenerator.
+
+    Args:
+        length (int): The number of (image, label) pairs to generate.
+        samples_by_character (Dict): The dictionary of emnist images of each character.
+        sentence_generator (SentenceGenerator): A SentenceGenerator object.
+        min_overlap (float): Minimum amount of overlap between Emnist images.
+        max_overlap (float): Maximum amount of overlap between Emnist images.
+
+    Returns:
+        Tuple[np.ndarray, List[str]]: An array of Emnist line images and a list of the label strings.
+
+    Raises:
+        RuntimeError: If the sentence generator is not able to generate a string.
+
+    """
+    sample_label = sentence_generator.generate()
+    sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0)
+    images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8)
+    labels = []
+    for n in range(length):
+        label = None
+        # Try several times to generate before actually throwing an error.
+        for _ in range(10):
+            try:
+                label = sentence_generator.generate()
+                break
+            except Exception:  # pylint: disable=broad-except
+                pass
+        if label is None:
+            raise RuntimeError("Was not able to generate a valid string.")
+        images[n] = construct_image_from_string(
+            label, samples_by_character, min_overlap, max_overlap
+        )
+        labels.append(label)
+    return images, labels
+
+
+def convert_strings_to_categorical_labels(
+    labels: List[str], mapping: Dict
+) -> np.ndarray:
+    """Translates a string of characters in to a target array of class int."""
+    return np.array([[mapping[c] for c in label] for label in labels])
+
+
+def create_datasets(
+    max_length: int = 34,
+    min_overlap: float = 0,
+    max_overlap: float = 0.33,
+    num_train: int = 10000,
+    num_val: int = 1000,
+) -> None:
+    """Creates a training an validation dataset of Emnist lines."""
+    emnist_train = EmnistDataset(train=True, sample_to_balance=True)
+    emnist_val = EmnistDataset(train=False, sample_to_balance=True)
+    datasets = [emnist_train, emnist_val]
+    num_samples = [num_train, num_val]
+    for num, train, dataset in zip(num_samples, [True, False], datasets):
+        emnist_lines = EmnistLinesDataset(
+            train=train,
+            emnist=dataset,
+            max_length=max_length,
+            min_overlap=min_overlap,
+            max_overlap=max_overlap,
+            num_samples=num,
+        )
+        emnist_lines._load_or_generate_data()
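A minimal sketch of generating the synthetic lines, assuming EMNIST is available locally. The first run writes an HDF5 cache under data/processed/emnist_lines; later runs load it.

    from text_recognizer.datasets import EmnistDataset, EmnistLinesDataset
    from text_recognizer.datasets.emnist_lines_dataset import create_datasets

    # Build both splits with the module-level helper.
    create_datasets(max_length=34, num_train=10000, num_val=1000)

    # Or build a single split by hand.
    emnist = EmnistDataset(train=True, sample_to_balance=True)
    lines = EmnistLinesDataset(emnist=emnist, train=True, num_samples=1000)
    lines._load_or_generate_data()
    image, target = lines[0]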
diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py
new file mode 100644
index 0000000..ee86bd4
--- /dev/null
+++ b/src/text_recognizer/datasets/sentence_generator.py
@@ -0,0 +1,81 @@
+"""Downloading the Brown corpus with NLTK for sentence generating."""
+
+import itertools
+import re
+import string
+from typing import Optional
+
+import nltk
+from nltk.corpus.reader.util import ConcatenatedCorpusView
+import numpy as np
+
+from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME
+
+NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
+
+
+class SentenceGenerator:
+    """Generates text sentences using the Brown corpus."""
+
+    def __init__(self, max_length: Optional[int] = None) -> None:
+        """Loads the corpus and sets word start indices."""
+        self.corpus = brown_corpus()
+        self.word_start_indices = [0] + [
+            _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
+        ]
+        self.max_length = max_length
+
+    def generate(self, max_length: Optional[int] = None) -> str:
+        """Generates a word or sentences from the Brown corpus.
+
+        Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
+        max_length with the '_' characters if sentence is shorter.
+
+        Args:
+            max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.
+
+        Returns:
+            str: A sentence from the Brown corpus.
+
+        Raises:
+            ValueError: If max_length was not specified at initialization and not given as an argument.
+
+        """
+        if max_length is None:
+            max_length = self.max_length
+        if max_length is None:
+            raise ValueError(
+                "Must provide max_length to this method or when making this object."
+            )
+
+        index = np.random.randint(0, len(self.word_start_indices) - 1)
+        start_index = self.word_start_indices[index]
+        end_index_candidates = []
+        for index in range(index + 1, len(self.word_start_indices)):
+            if self.word_start_indices[index] - start_index > max_length:
+                break
+            end_index_candidates.append(self.word_start_indices[index])
+        end_index = np.random.choice(end_index_candidates)
+        sampled_text = self.corpus[start_index:end_index].strip()
+        padding = "_" * (max_length - len(sampled_text))
+        return sampled_text + padding
+
+
+def brown_corpus() -> str:
+    """Returns a single string with the Brown corpus with all punctuations stripped."""
+    sentences = load_nltk_brown_corpus()
+    corpus = " ".join(itertools.chain.from_iterable(sentences))
+    corpus = corpus.translate({ord(c): None for c in string.punctuation})
+    corpus = re.sub(" +", " ", corpus)
+    return corpus
+
+
+def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
+    """Load the Brown corpus using the NLTK library."""
+    nltk.data.path.append(NLTK_DATA_DIRNAME)
+    try:
+        nltk.corpus.brown.sents()
+    except LookupError:
+        NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+        nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
+    return nltk.corpus.brown.sents()
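A short usage sketch of the generator; the first call downloads the Brown corpus into data/raw/nltk if it is missing. The sample string in the comment is illustrative.

    from text_recognizer.datasets import SentenceGenerator

    generator = SentenceGenerator(max_length=34)
    line = generator.generate()  # e.g. "he was a member of the___________", padded with '_' to 34 characters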
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
new file mode 100644
index 0000000..6668eef
--- /dev/null
+++ b/src/text_recognizer/datasets/util.py
@@ -0,0 +1,11 @@
+"""Util functions for datasets."""
+import numpy as np
+from PIL import Image
+
+
+class Transpose:
+    """Transposes the EMNIST image to the correct orientation."""
+
+    def __call__(self, image: Image) -> np.ndarray:
+        """Swaps axis."""
+        return np.array(image).swapaxes(0, 1)
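Transpose is the first step of the default EmnistDataset transform; a small self-contained sketch with a dummy image:

    import numpy as np
    from PIL import Image
    from torchvision.transforms import Compose, ToTensor

    from text_recognizer.datasets import Transpose

    transform = Compose([Transpose(), ToTensor()])
    image = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))  # dummy EMNIST-sized image
    tensor = transform(image)  # torch.Size([1, 28, 28])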
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index b78eacb..84a86ca 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -22,7 +22,7 @@ class Model(ABC):
     def __init__(
         self,
         network_fn: Type[nn.Module],
-        network_args: Dict,
+        network_args: Optional[Dict] = None,
         data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
@@ -38,7 +38,7 @@ class Model(ABC):
 
         Args:
             network_fn (Type[nn.Module]): The PyTorch network.
-            network_args (Dict): Arguments for the network.
+            network_args (Optional[Dict]): Arguments for the network. Defaults to None.
             data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
             data_loader_args (Optional[Dict]):  Arguments for the DataLoader.
             metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
@@ -58,18 +58,14 @@ class Model(ABC):
         if data_loader_args is not None:
             self._data_loaders = data_loader(**data_loader_args)
             dataset_name = self._data_loaders.__name__
+            self._mapping = self._data_loaders.mapping
         else:
+            self._mapping = None
             dataset_name = "*"
             self._data_loaders = None
 
         self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
 
-        # Extract the input shape for the torchsummary.
-        if isinstance(network_args["input_size"], int):
-            self._input_shape = (1,) + tuple([network_args["input_size"]])
-        else:
-            self._input_shape = (1,) + tuple(network_args["input_size"])
-
         if metrics is not None:
             self._metrics = metrics
 
@@ -80,8 +76,13 @@ class Model(ABC):
             self._device = device
 
         # Load network.
-        self.network_args = network_args
-        self._network = network_fn(**self.network_args)
+        self._network = None
+        self._network_args = network_args
+        # If no network arguments are given, load pretrained weights if they exist.
+        if self._network_args is None:
+            self.load_weights(network_fn)
+        else:
+            self._network = network_fn(**self._network_args)
 
         # To device.
         self._network.to(self._device)
@@ -104,8 +105,17 @@ class Model(ABC):
                 lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
             self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
 
-        # Class mapping.
-        self._mapping = None
+        # Extract the input shape for the torchsummary.
+        if isinstance(self._network_args["input_size"], int):
+            self._input_shape = (1,) + tuple([self._network_args["input_size"]])
+        else:
+            self._input_shape = (1,) + tuple(self._network_args["input_size"])
+
+        # Experiment directory.
+        self.model_dir = None
+
+        # Flag for stopping training.
+        self.stop_training = False
 
     @property
     def __name__(self) -> str:
@@ -179,8 +189,13 @@ class Model(ABC):
     def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
         state = {"model_state": self._network.state_dict()}
+
         if self._optimizer is not None:
             state["optimizer_state"] = self._optimizer.state_dict()
+
+        if self._lr_scheduler is not None:
+            state["scheduler_state"] = self._lr_scheduler.state_dict()
+
         return state
 
     def load_checkpoint(self, path: Path) -> int:
@@ -203,54 +218,63 @@ class Model(ABC):
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])
 
+        if self._lr_scheduler is not None:
+            self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+
         epoch = checkpoint["epoch"]
 
         return epoch
 
-    def save_checkpoint(
-        self, path: Path, is_best: bool, epoch: int, val_metric: str
-    ) -> None:
+    def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
         """Saves a checkpoint of the model.
 
         Args:
-            path (Path): Path to the experiment folder.
             is_best (bool): If it is the currently best model.
             epoch (int): The epoch of the checkpoint.
             val_metric (str): Validation metric.
 
+        Raises:
+            ValueError: If self.model_dir is not set.
+
         """
         state = self._get_state_dict()
         state["is_best"] = is_best
         state["epoch"] = epoch
-        state["network_args"] = self.network_args
+        state["network_args"] = self._network_args
 
-        path.mkdir(parents=True, exist_ok=True)
+        if self.model_dir is None:
+            raise ValueError("Experiment directory is not set.")
+
+        self.model_dir.mkdir(parents=True, exist_ok=True)
 
         logger.debug("Saving checkpoint...")
-        filepath = str(path / "last.pt")
+        filepath = str(self.model_dir / "last.pt")
         torch.save(state, filepath)
 
         if is_best:
             logger.debug(
                 f"Found a new best {val_metric}. Saving best checkpoint and weights."
             )
-            shutil.copyfile(filepath, str(path / "best.pt"))
+            shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
 
-    def load_weights(self) -> None:
+    def load_weights(self, network_fn: Type[nn.Module]) -> None:
         """Load the network weights."""
-        logger.debug("Loading network weights.")
+        logger.debug("Loading network with pretrained weights.")
-        filename = glob(self.weights_filename)[0]
-        weights = torch.load(filename, map_location=torch.device(self._device))[
-            "model_state"
-        ]
+        filenames = glob(self.weights_filename)
+        if not filenames:
+            raise FileNotFoundError(
+                f"Could not find any pretrained weights at {self.weights_filename}"
+            )
+        filename = filenames[0]
+        # Load the saved state dict.
+        state_dict = torch.load(filename, map_location=torch.device(self._device))
+        self._network_args = state_dict["network_args"]
+        weights = state_dict["model_state"]
+
+        # Initializes the network with trained weights.
+        self._network = network_fn(**self._network_args)
         self._network.load_state_dict(weights)
 
     def save_weights(self, path: Path) -> None:
         """Save the network weights."""
         logger.debug("Saving the best network weights.")
         shutil.copyfile(str(path / "best.pt"), self.weights_filename)
-
-    @abstractmethod
-    def load_mapping(self) -> None:
-        """Loads class mapping from network output to character."""
-        ...
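With network_args now optional, a Model subclass can be constructed either from explicit arguments or from the network_args stored in its saved weights. A sketch of both paths through CharacterModel; the argument values mirror the ones previously hard-coded in test_character_predictor.py.

    from text_recognizer.models.character_model import CharacterModel
    from text_recognizer.networks import MLP

    # Explicit architecture, freshly initialized weights.
    model = CharacterModel(
        network_fn=MLP,
        network_args={"input_size": 784, "output_size": 62, "dropout_rate": 0.2},
    )

    # No network_args: the network is rebuilt from the checkpoint's "network_args"
    # entry and its "model_state" weights via load_weights().
    pretrained = CharacterModel(network_fn=MLP)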
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 527fc7d..f1dabb7 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,12 +1,15 @@
 """Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
 from torch import nn
 from torchvision.transforms import ToTensor
 
-from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
+from text_recognizer.datasets.emnist_dataset import (
+    _augment_emnist_mapping,
+    _load_emnist_essentials,
+)
 from text_recognizer.models.base import Model
 
 
@@ -16,7 +19,7 @@ class CharacterModel(Model):
     def __init__(
         self,
         network_fn: Type[nn.Module],
-        network_args: Dict,
+        network_args: Optional[Dict] = None,
         data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
@@ -44,19 +47,23 @@ class CharacterModel(Model):
             lr_scheduler_args,
             device,
         )
-        self.load_mapping()
+        if self.mapping is None:
+            self.load_mapping()
         self.tensor_transform = ToTensor()
         self.softmax = nn.Softmax(dim=0)
 
     def load_mapping(self) -> None:
         """Mapping between integers and classes."""
-        self._mapping = load_emnist_mapping()
+        essentials = _load_emnist_essentials()
+        self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
 
-    def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
+    def predict_on_image(
+        self, image: Union[np.ndarray, torch.Tensor]
+    ) -> Tuple[str, float]:
         """Character prediction on an image.
 
         Args:
-            image (np.ndarray): An image containing a character.
+            image (Union[np.ndarray, torch.Tensor]): An image containing a character.
 
         Returns:
             Tuple[str, float]: The predicted character and the confidence in the prediction.
@@ -64,12 +71,15 @@ class CharacterModel(Model):
         """
 
         if image.dtype == np.uint8:
-            image = (image / 255).astype(np.float32)
-
-        # Conver to Pytorch Tensor.
-        image = self.tensor_transform(image)
+            # Converts an image with range [0, 255] to a PyTorch tensor with range [0, 1].
+            image = self.tensor_transform(image)
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
 
         with torch.no_grad():
+            # Put the image tensor on the device the model weights are on.
+            image = image.to(self.device)
             logits = self.network(image)
 
         prediction = self.softmax(logits.data.squeeze())
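A short sketch of the widened predict_on_image input handling, assuming pretrained CharacterModel weights exist; the random array stands in for a real 28x28 character crop.

    import numpy as np

    from text_recognizer.models.character_model import CharacterModel
    from text_recognizer.networks import MLP

    model = CharacterModel(network_fn=MLP)  # loads pretrained weights
    model.eval()

    image = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)  # stand-in character image
    character, confidence = model.predict_on_image(image)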
diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py
index c603a3a..01bda78 100644
--- a/src/text_recognizer/tests/test_character_predictor.py
+++ b/src/text_recognizer/tests/test_character_predictor.py
@@ -4,7 +4,6 @@ import os
 from pathlib import Path
 import unittest
 
-import click
 from loguru import logger
 
 from text_recognizer.character_predictor import CharacterPredictor
+from text_recognizer.networks import MLP
@@ -18,19 +17,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = ""
 class TestCharacterPredictor(unittest.TestCase):
     """Tests for the CharacterPredictor class."""
 
-    # @click.command()
-    # @click.option(
-    #     "--network", type=str, help="Network to load, e.g. MLP or LeNet.", default="MLP"
-    # )
     def test_filename(self) -> None:
         """Test that CharacterPredictor correctly predicts on a single image, for serveral test images."""
-        network_module = importlib.import_module("text_recognizer.networks")
-        network_fn_ = getattr(network_module, "MLP")
-        # network_args = {"input_size": [28, 28], "output_size": 62, "dropout_rate": 0}
-        network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2}
-        predictor = CharacterPredictor(
-            network_fn=network_fn_, network_args=network_args
-        )
+        network_fn_ = MLP
+        predictor = CharacterPredictor(network_fn=network_fn_)
 
         for filename in SUPPORT_DIRNAME.glob("*.png"):
             pred, conf = predictor.predict(str(filename))
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
index 43a3891..46b1cb1 100644
Binary files a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt and b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt differ
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt
index 0dde787..4ec12c1 100644
Binary files a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt and b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt differ
-- 
cgit v1.2.3-70-g09d2