From 07dd14116fe1d8148fb614b160245287533620fc Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Mon, 3 Aug 2020 23:33:34 +0200 Subject: Working Emnist lines dataset. --- src/text_recognizer/character_predictor.py | 5 +- src/text_recognizer/datasets/__init__.py | 24 +- src/text_recognizer/datasets/emnist_dataset.py | 279 +++++++++++++----- .../datasets/emnist_lines_dataset.py | 326 +++++++++++++++++++++ src/text_recognizer/datasets/sentence_generator.py | 81 +++++ src/text_recognizer/datasets/util.py | 11 + src/text_recognizer/models/base.py | 84 ++++-- src/text_recognizer/models/character_model.py | 32 +- .../tests/test_character_predictor.py | 14 +- .../weights/CharacterModel_Emnist_LeNet_weights.pt | Bin 14483400 -> 14485305 bytes .../weights/CharacterModel_Emnist_MLP_weights.pt | Bin 1702233 -> 1704096 bytes 11 files changed, 724 insertions(+), 132 deletions(-) create mode 100644 src/text_recognizer/datasets/emnist_lines_dataset.py create mode 100644 src/text_recognizer/datasets/sentence_generator.py create mode 100644 src/text_recognizer/datasets/util.py (limited to 'src/text_recognizer') diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py index a773f36..b733a53 100644 --- a/src/text_recognizer/character_predictor.py +++ b/src/text_recognizer/character_predictor.py @@ -11,10 +11,9 @@ from text_recognizer.util import read_image class CharacterPredictor: """Recognizes the character in handwritten character images.""" - def __init__(self, network_fn: Type[nn.Module], network_args: Dict) -> None: + def __init__(self, network_fn: Type[nn.Module]) -> None: """Intializes the CharacterModel and load the pretrained weights.""" - self.model = CharacterModel(network_fn=network_fn, network_args=network_args) - self.model.load_weights() + self.model = CharacterModel(network_fn=network_fn) self.model.eval() def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 795be90..bfa6a02 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,4 +1,24 @@ """Dataset modules.""" -from .emnist_dataset import EmnistDataLoader +from .emnist_dataset import ( + DATA_DIRNAME, + EmnistDataLoaders, + EmnistDataset, +) +from .emnist_lines_dataset import ( + construct_image_from_string, + EmnistLinesDataset, + get_samples_by_character, +) +from .sentence_generator import SentenceGenerator +from .util import Transpose -__all__ = ["EmnistDataLoader"] +__all__ = [ + "construct_image_from_string", + "DATA_DIRNAME", + "EmnistDataset", + "EmnistDataLoaders", + "EmnistLinesDataset", + "get_samples_by_character", + "SentenceGenerator", + "Transpose", +] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index b92b57d..525df95 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -1,29 +1,23 @@ -"""Fetches a PyTorch DataLoader with the EMNIST dataset.""" +"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" import json from pathlib import Path -from typing import Callable, Dict, List, Optional, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader +import torch +from torch.utils.data import DataLoader, Dataset from 
torchvision.datasets import EMNIST -from torchvision.transforms import Compose, ToTensor +from torchvision.transforms import Compose, Normalize, ToTensor +from text_recognizer.datasets.util import Transpose DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes @@ -45,14 +39,187 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def load_emnist_mapping() -> Dict[int, str]: +def _load_emnist_essentials() -> Dict: """Load the EMNIST mapping.""" with open(str(ESSENTIALS_FILENAME)) as f: essentials = json.load(f) - return dict(essentials["mapping"]) + return essentials + + +def _augment_emnist_mapping(mapping: Dict) -> Dict: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol + extra_symbols.append("_") + + max_key = max(mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + return {**mapping, **extra_mapping} + + +class EmnistDataset(Dataset): + """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" + + def __init__( + self, + train: bool = False, + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + seed: int = 4711, + ) -> None: + """Loads the dataset and the mappings. + + Args: + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to + False. + sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. + subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + transform (Optional[Callable]): Transform(s) for input data. Defaults to None. + target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + seed (int): Seed number. Defaults to 4711. + + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + + """ + + self.train = train + self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction + self.transform = transform + if self.transform is None: + self.transform = Compose([Transpose(), ToTensor()]) + + self.target_transform = target_transform + self.seed = seed + + # Load dataset infromation. 
+ essentials = _load_emnist_essentials() + self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) + self.inverse_mapping = {v: k for k, v in self.mapping.items()} + self.num_classes = len(self.mapping) + self.input_shape = essentials["input_shape"] + + # Placeholders + self.data = None + self.targets = None + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def __getitem__( + self, index: Union[int, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, torch.Tensor]): The indices of the samples to fetch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Data target tuple. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Dataset\n" + f"Num classes: {self.num_classes}\n" + f"Mapping: {self.mapping}\n" + f"Input shape: {self.input_shape}\n" + ) + + def _sample_to_balance(self) -> None: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = self.data + y = self.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + self.data = x_sampled + self.targets = y_sampled + + def _subsample(self) -> None: + """Subsamples the dataset to the specified fraction.""" + x = self.data + y = self.targets + num_samples = int(x.shape[0] * self.subsample_fraction) + x_sampled = x[:num_samples] + y_sampled = y[:num_samples] + self.data = x_sampled + self.targets = y_sampled + + def load_emnist_dataset(self) -> None: + """Fetch the EMNIST dataset.""" + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=self.train, + download=False, + transform=None, + target_transform=None, + ) + + self.data = dataset.data + self.targets = dataset.targets + + if self.sample_to_balance: + self._sample_to_balance() + + if self.subsample_fraction is not None: + self._subsample() -class EmnistDataLoader: +class EmnistDataLoaders: """Class for Emnist DataLoaders.""" def __init__( @@ -68,7 +235,7 @@ class EmnistDataLoader: cuda: bool = True, seed: int = 4711, ) -> None: - """Fetches DataLoaders. + """Fetches DataLoaders for given split(s). Args: splits (List[str]): One or both of the dataset splits "train" and "val". @@ -88,13 +255,17 @@ class EmnistDataLoader: them. Defaults to True. seed (int): Seed for sampling. + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + """ self.splits = splits self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: - assert ( - 0.0 < subsample_fraction < 1.0 - ), " The subsample fraction must be in (0, 1)." 
+ if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction self.transform = transform self.target_transform = target_transform @@ -105,6 +276,10 @@ class EmnistDataLoader: self.seed = seed self._data_loaders = self._fetch_emnist_data_loaders() + def __repr__(self) -> str: + """Returns information about the dataset.""" + return self._data_loaders[self.splits[0]].dataset.__repr__() + @property def __name__(self) -> str: """Returns the name of the dataset.""" @@ -128,59 +303,6 @@ class EmnistDataLoader: except KeyError: raise ValueError(f"Split {split} does not exist.") - def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(self.seed) - x = dataset.data - y = dataset.targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_indices = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_indices.append(sampled_indices) - indices = np.concatenate(all_sampled_indices) - x_sampled = x[indices] - y_sampled = y[indices] - dataset.data = x_sampled - dataset.targets = y_sampled - - return dataset - - def _subsample(self, dataset: type = EMNIST) -> EMNIST: - """Subsamples the dataset to the specified fraction.""" - x = dataset.data - y = dataset.targets - num_samples = int(x.shape[0] * self.subsample_fraction) - x_sampled = x[:num_samples] - y_sampled = y[:num_samples] - dataset.data = x_sampled - dataset.targets = y_sampled - - return dataset - - def _fetch_emnist_dataset(self, train: bool) -> EMNIST: - """Fetch the EMNIST dataset.""" - if self.transform is None: - transform = Compose([Transpose(), ToTensor()]) - - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=train, - download=False, - transform=transform, - target_transform=self.target_transform, - ) - - if self.sample_to_balance: - dataset = self._sample_to_balance(dataset) - - if self.subsample_fraction is not None: - dataset = self._subsample(dataset) - - return dataset - def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" data_loaders = {} @@ -193,10 +315,19 @@ class EmnistDataLoader: else: train = False - dataset = self._fetch_emnist_dataset(train) + emnist_dataset = EmnistDataset( + train=train, + sample_to_balance=self.sample_to_balance, + subsample_fraction=self.subsample_fraction, + transform=self.transform, + target_transform=self.target_transform, + seed=self.seed, + ) + + emnist_dataset.load_emnist_dataset() data_loader = DataLoader( - dataset=dataset, + dataset=emnist_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py new file mode 100644 index 0000000..d49319f --- /dev/null +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -0,0 +1,326 @@ +"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters.""" + +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision.transforms import Compose, Normalize, 
ToTensor + +from text_recognizer.datasets import DATA_DIRNAME, EmnistDataset, SentenceGenerator +from text_recognizer.datasets.util import Transpose + +DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" +ESSENTIALS_FILENAME = ( + Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json" +) + + +class EmnistLinesDataset(Dataset): + """Synthetic dataset of lines from the Brown corpus with Emnist characters.""" + + def __init__( + self, + emnist: EmnistDataset, + train: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_samples: int = 10000, + seed: int = 4711, + ) -> None: + """Short summary. + + Args: + emnist (EmnistDataset): A EmnistDataset object. + train (bool): Flag for the filename. Defaults to False. + transform (Optional[Callable]): The transform of the data. Defaults to None. + target_transform (Optional[Callable]): The transform of the target. Defaults to None. + max_length (int): The maximum number of characters. Defaults to 34. + min_overlap (float): The minimum overlap between concatenated images. Defaults to 0. + max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33. + num_samples (int): Number of samples to generate. Defaults to 10000. + seed (int): Seed number. Defaults to 4711. + + """ + self.train = train + self.emnist = emnist + + self.transform = transform + if self.transform is None: + self.transform = Compose([ToTensor()]) + + self.target_transform = target_transform + if self.target_transform is None: + self.target_transform = torch.tensor + + self.mapping = self.emnist.mapping + self.num_classes = self.emnist.num_classes + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_samples = num_samples + self.input_shape = ( + self.emnist.input_shape[0], + self.emnist.input_shape[1] * self.max_length, + ) + self.output_shape = (self.max_length, self.num_classes) + self.seed = seed + + # Placeholders for the generated dataset. + self.data = None + self.target = None + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def __getitem__( + self, index: Union[int, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, torch.Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Data target pair. 
+ + """ + if torch.is_tensor(index): + index = index.tolist() + + # data = np.array([self.data[index]]) + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Lines Dataset\n" # pylint: disable=no-member + f"Max length: {self.max_length}\n" + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {self.num_classes}\n" + f"Input shape: {self.input_shape}\n" + f"Data: {self.data.shape}\n" + f"Tagets: {self.targets.shape}\n" + ) + + @property + def data_filename(self) -> Path: + """Path to the h5 file.""" + filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" + if self.train: + filename = "train_" + filename + else: + filename = "val_" + filename + return DATA_DIRNAME / filename + + def _load_or_generate_data(self) -> None: + """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" + np.random.seed(self.seed) + + if not self.data_filename.exists(): + self._generate_data() + self._load_data() + + def _load_data(self) -> None: + """Loads the dataset from the h5 file.""" + logger.debug("EmnistLinesDataset loading data from HDF5...") + with h5py.File(self.data_filename, "r") as f: + self.data = f["data"][:] + self.targets = f["targets"][:] + + def _generate_data(self) -> str: + """Generates a dataset with the Brown corpus and Emnist characters.""" + logger.debug("Generating data...") + + sentence_generator = SentenceGenerator(self.max_length) + + # Load emnist dataset. + self.emnist.load_emnist_dataset() + samples_by_character = get_samples_by_character( + self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping, + ) + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "a") as f: + data, targets = create_dataset_of_images( + self.num_samples, + samples_by_character, + sentence_generator, + self.min_overlap, + self.max_overlap, + ) + + targets = convert_strings_to_categorical_labels( + targets, self.emnist.inverse_mapping + ) + + f.create_dataset("data", data=data, dtype="u1", compression="lzf") + f.create_dataset("targets", data=targets, dtype="u1", compression="lzf") + + +def get_samples_by_character( + samples: np.ndarray, labels: np.ndarray, mapping: Dict +) -> defaultdict: + """Creates a dictionary with character as key and value as the list of images of that character. + + Args: + samples (np.ndarray): Dataset of images of characters. + labels (np.ndarray): The labels for each image. + mapping (Dict): The Emnist mapping dictionary. + + Returns: + defaultdict: A dictionary with characters as keys and list of images as values. + + """ + samples_by_character = defaultdict(list) + for sample, label in zip(samples, labels.flatten()): + samples_by_character[mapping[label]].append(sample) + return samples_by_character + + +def select_letter_samples_for_string( + string: str, samples_by_character: Dict +) -> List[np.ndarray]: + """Randomly selects Emnist characters to use for the senetence. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + + Returns: + List[np.ndarray]: A list of emnist images of the string. 
+ + """ + zero_image = np.zeros((28, 28), np.uint8) + sample_image_by_character = {} + for character in string: + if character in sample_image_by_character: + continue + samples = samples_by_character[character] + sample = samples[np.random.choice(len(samples))] if samples else zero_image + sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1) + return [sample_image_by_character[character] for character in string] + + +def construct_image_from_string( + string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float +) -> np.ndarray: + """Concatenates images of the characters in the string. + + The concatination is made with randomly selected overlap so that some portion of the character will overlap. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + np.ndarray: The Emnist image of the string. + + """ + overlap = np.random.uniform(min_overlap, max_overlap) + sampled_images = select_letter_samples_for_string(string, samples_by_character) + length = len(sampled_images) + height, width = sampled_images[0].shape + next_overlap_width = width - int(overlap * width) + concatenated_image = np.zeros((height, width * length), np.uint8) + x = 0 + for image in sampled_images: + concatenated_image[:, x : (x + width)] += image + x += next_overlap_width + return np.minimum(255, concatenated_image) + + +def create_dataset_of_images( + length: int, + samples_by_character: Dict, + sentence_generator: SentenceGenerator, + min_overlap: float, + max_overlap: float, +) -> Tuple[np.ndarray, List[str]]: + """Creates a dataset with images and labels from strings generated from the SentenceGenerator. + + Args: + length (int): The number of characters for each string. + samples_by_character (Dict): The dictionary of emnist images of each character. + sentence_generator (SentenceGenerator): A SentenceGenerator objest. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels). + + Raises: + RuntimeError: If the sentence generator is not able to generate a string. + + """ + sample_label = sentence_generator.generate() + sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0) + images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8) + labels = [] + for n in range(length): + label = None + # Try several times to generate before actually throwing an error. 
+ for _ in range(10): + try: + label = sentence_generator.generate() + break + except Exception: # pylint: disable=broad-except + pass + if label is None: + raise RuntimeError("Was not able to generate a valid string.") + images[n] = construct_image_from_string( + label, samples_by_character, min_overlap, max_overlap + ) + labels.append(label) + return images, labels + + +def convert_strings_to_categorical_labels( + labels: List[str], mapping: Dict +) -> np.ndarray: + """Translates a string of characters in to a target array of class int.""" + return np.array([[mapping[c] for c in label] for label in labels]) + + +def create_datasets( + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_train: int = 10000, + num_val: int = 1000, +) -> None: + """Creates a training an validation dataset of Emnist lines.""" + emnist_train = EmnistDataset(train=True, sample_to_balance=True) + emnist_val = EmnistDataset(train=False, sample_to_balance=True) + datasets = [emnist_train, emnist_val] + num_samples = [num_train, num_val] + for num, train, dataset in zip(num_samples, [True, False], datasets): + emnist_lines = EmnistLinesDataset( + train=train, + emnist=dataset, + max_length=max_length, + min_overlap=min_overlap, + max_overlap=max_overlap, + num_samples=num, + ) + emnist_lines._load_or_generate_data() diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py new file mode 100644 index 0000000..ee86bd4 --- /dev/null +++ b/src/text_recognizer/datasets/sentence_generator.py @@ -0,0 +1,81 @@ +"""Downloading the Brown corpus with NLTK for sentence generating.""" + +import itertools +import re +import string +from typing import Optional + +import nltk +from nltk.corpus.reader.util import ConcatenatedCorpusView +import numpy as np + +from text_recognizer.datasets import DATA_DIRNAME + +NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" + + +class SentenceGenerator: + """Generates text sentences using the Brown corpus.""" + + def __init__(self, max_length: Optional[int] = None) -> None: + """Loads the corpus and sets word start indices.""" + self.corpus = brown_corpus() + self.word_start_indices = [0] + [ + _.start(0) + 1 for _ in re.finditer(" ", self.corpus) + ] + self.max_length = max_length + + def generate(self, max_length: Optional[int] = None) -> str: + """Generates a word or sentences from the Brown corpus. + + Sample a string from the Brown corpus of length at least one word and at most max_length, padding to + max_length with the '_' characters if sentence is shorter. + + Args: + max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. + + Returns: + str: A sentence from the Brown corpus. + + Raises: + ValueError: If max_length was not specified at initialization and not given as an argument. + + """ + if max_length is None: + max_length = self.max_length + if max_length is None: + raise ValueError( + "Must provide max_length to this method or when making this object." 
+ ) + + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + padding = "_" * (max_length - len(sampled_text)) + return sampled_text + padding + + +def brown_corpus() -> str: + """Returns a single string with the Brown corpus with all punctuations stripped.""" + sentences = load_nltk_brown_corpus() + corpus = " ".join(itertools.chain.from_iterable(sentences)) + corpus = corpus.translate({ord(c): None for c in string.punctuation}) + corpus = re.sub(" +", " ", corpus) + return corpus + + +def load_nltk_brown_corpus() -> ConcatenatedCorpusView: + """Load the Brown corpus using the NLTK library.""" + nltk.data.path.append(NLTK_DATA_DIRNAME) + try: + nltk.corpus.brown.sents() + except LookupError: + NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) + return nltk.corpus.brown.sents() diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py new file mode 100644 index 0000000..6668eef --- /dev/null +++ b/src/text_recognizer/datasets/util.py @@ -0,0 +1,11 @@ +"""Util functions for datasets.""" +import numpy as np +from PIL import Image + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index b78eacb..84a86ca 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -22,7 +22,7 @@ class Model(ABC): def __init__( self, network_fn: Type[nn.Module], - network_args: Dict, + network_args: Optional[Dict] = None, data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, @@ -38,7 +38,7 @@ class Model(ABC): Args: network_fn (Type[nn.Module]): The PyTorch network. - network_args (Dict): Arguments for the network. + network_args (Optional[Dict]): Arguments for the network. Defaults to None. data_loader (Optional[Callable]): A function that fetches train and val DataLoader. data_loader_args (Optional[Dict]): Arguments for the DataLoader. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. @@ -58,18 +58,14 @@ class Model(ABC): if data_loader_args is not None: self._data_loaders = data_loader(**data_loader_args) dataset_name = self._data_loaders.__name__ + self._mapping = self._data_loaders.mapping else: + self._mapping = None dataset_name = "*" self._data_loaders = None self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" - # Extract the input shape for the torchsummary. - if isinstance(network_args["input_size"], int): - self._input_shape = (1,) + tuple([network_args["input_size"]]) - else: - self._input_shape = (1,) + tuple(network_args["input_size"]) - if metrics is not None: self._metrics = metrics @@ -80,8 +76,13 @@ class Model(ABC): self._device = device # Load network. 
- self.network_args = network_args - self._network = network_fn(**self.network_args) + self._network = None + self._network_args = network_args + # If no network arguemnts are given, load pretrained weights if they exist. + if self._network_args is None: + self.load_weights(network_fn) + else: + self._network = network_fn(**self._network_args) # To device. self._network.to(self._device) @@ -104,8 +105,17 @@ class Model(ABC): lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train")) self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) - # Class mapping. - self._mapping = None + # Extract the input shape for the torchsummary. + if isinstance(self._network_args["input_size"], int): + self._input_shape = (1,) + tuple([self._network_args["input_size"]]) + else: + self._input_shape = (1,) + tuple(self._network_args["input_size"]) + + # Experiment directory. + self.model_dir = None + + # Flag for stopping training. + self.stop_training = False @property def __name__(self) -> str: @@ -179,8 +189,13 @@ class Model(ABC): def _get_state_dict(self) -> Dict: """Get the state dict of the model.""" state = {"model_state": self._network.state_dict()} + if self._optimizer is not None: state["optimizer_state"] = self._optimizer.state_dict() + + if self._lr_scheduler is not None: + state["scheduler_state"] = self._lr_scheduler.state_dict() + return state def load_checkpoint(self, path: Path) -> int: @@ -203,54 +218,63 @@ class Model(ABC): if self._optimizer is not None: self._optimizer.load_state_dict(checkpoint["optimizer_state"]) + if self._lr_scheduler is not None: + self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) + epoch = checkpoint["epoch"] return epoch - def save_checkpoint( - self, path: Path, is_best: bool, epoch: int, val_metric: str - ) -> None: + def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None: """Saves a checkpoint of the model. Args: - path (Path): Path to the experiment folder. is_best (bool): If it is the currently best model. epoch (int): The epoch of the checkpoint. val_metric (str): Validation metric. + Raises: + ValueError: If the self.model_dir is not set. + """ state = self._get_state_dict() state["is_best"] = is_best state["epoch"] = epoch - state["network_args"] = self.network_args + state["network_args"] = self._network_args - path.mkdir(parents=True, exist_ok=True) + if self.model_dir is None: + raise ValueError("Experiment directory is not set.") + + self.model_dir.mkdir(parents=True, exist_ok=True) logger.debug("Saving checkpoint...") - filepath = str(path / "last.pt") + filepath = str(self.model_dir / "last.pt") torch.save(state, filepath) if is_best: logger.debug( f"Found a new best {val_metric}. Saving best checkpoint and weights." ) - shutil.copyfile(filepath, str(path / "best.pt")) + shutil.copyfile(filepath, str(self.model_dir / "best.pt")) - def load_weights(self) -> None: + def load_weights(self, network_fn: Type[nn.Module]) -> None: """Load the network weights.""" - logger.debug("Loading network weights.") + logger.debug("Loading network with pretrained weights.") filename = glob(self.weights_filename)[0] - weights = torch.load(filename, map_location=torch.device(self._device))[ - "model_state" - ] + if not filename: + raise FileNotFoundError( + f"Could not find any pretrained weights at {self.weights_filename}" + ) + # Loading state directory. 
+ state_dict = torch.load(filename, map_location=torch.device(self._device)) + self._network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. + self._network = network_fn(**self._network_args) self._network.load_state_dict(weights) def save_weights(self, path: Path) -> None: """Save the network weights.""" logger.debug("Saving the best network weights.") shutil.copyfile(str(path / "best.pt"), self.weights_filename) - - @abstractmethod - def load_mapping(self) -> None: - """Loads class mapping from network output to character.""" - ... diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 527fc7d..f1dabb7 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -1,12 +1,15 @@ """Defines the CharacterModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type +from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch from torch import nn from torchvision.transforms import ToTensor -from text_recognizer.datasets.emnist_dataset import load_emnist_mapping +from text_recognizer.datasets.emnist_dataset import ( + _augment_emnist_mapping, + _load_emnist_essentials, +) from text_recognizer.models.base import Model @@ -16,7 +19,7 @@ class CharacterModel(Model): def __init__( self, network_fn: Type[nn.Module], - network_args: Dict, + network_args: Optional[Dict] = None, data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, @@ -44,19 +47,23 @@ class CharacterModel(Model): lr_scheduler_args, device, ) - self.load_mapping() + if self.mapping is None: + self.load_mapping() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) def load_mapping(self) -> None: """Mapping between integers and classes.""" - self._mapping = load_emnist_mapping() + essentials = _load_emnist_essentials() + self._mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]: + def predict_on_image( + self, image: Union[np.ndarray, torch.Tensor] + ) -> Tuple[str, float]: """Character prediction on an image. Args: - image (np.ndarray): An image containing a character. + image (Union[np.ndarray, torch.Tensor]): An image containing a character. Returns: Tuple[str, float]: The predicted character and the confidence in the prediction. @@ -64,12 +71,15 @@ class CharacterModel(Model): """ if image.dtype == np.uint8: - image = (image / 255).astype(np.float32) - - # Conver to Pytorch Tensor. - image = self.tensor_transform(image) + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 with torch.no_grad(): + # Put the image tensor on the device the model weights are on. 
+ image = image.to(self.device) logits = self.network(image) prediction = self.softmax(logits.data.squeeze()) diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py index c603a3a..01bda78 100644 --- a/src/text_recognizer/tests/test_character_predictor.py +++ b/src/text_recognizer/tests/test_character_predictor.py @@ -4,7 +4,6 @@ import os from pathlib import Path import unittest -import click from loguru import logger from text_recognizer.character_predictor import CharacterPredictor @@ -18,19 +17,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "" class TestCharacterPredictor(unittest.TestCase): """Tests for the CharacterPredictor class.""" - # @click.command() - # @click.option( - # "--network", type=str, help="Network to load, e.g. MLP or LeNet.", default="MLP" - # ) def test_filename(self) -> None: """Test that CharacterPredictor correctly predicts on a single image, for serveral test images.""" - network_module = importlib.import_module("text_recognizer.networks") - network_fn_ = getattr(network_module, "MLP") - # network_args = {"input_size": [28, 28], "output_size": 62, "dropout_rate": 0} - network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2} - predictor = CharacterPredictor( - network_fn=network_fn_, network_args=network_args - ) + network_fn_ = MLP + predictor = CharacterPredictor(network_fn=network_fn_) for filename in SUPPORT_DIRNAME.glob("*.png"): pred, conf = predictor.predict(str(filename)) diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt index 43a3891..46b1cb1 100644 Binary files a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt and b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt differ diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt index 0dde787..4ec12c1 100644 Binary files a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt and b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt differ -- cgit v1.2.3-70-g09d2
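For reference, a minimal usage sketch of the SentenceGenerator added in this patch (the max_length value is illustrative; the Brown corpus is fetched on first use if it is not already under data/raw/nltk):

from text_recognizer.datasets import SentenceGenerator

# Sample a string of at most 34 characters from the Brown corpus,
# right-padded with the '_' padding symbol added to the EMNIST mapping.
generator = SentenceGenerator(max_length=34)
line = generator.generate()
print(line)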
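Likewise, a sketch of generating and reading the synthetic EMNIST Lines data end to end, assuming the raw EMNIST data is already downloaded under the data directory (the sample count below is illustrative):

from text_recognizer.datasets import EmnistDataset, EmnistLinesDataset

# Balanced EMNIST characters are the building blocks for the synthetic lines.
emnist = EmnistDataset(train=True, sample_to_balance=True)
emnist.load_emnist_dataset()

# Generate the lines, or load them from the cached HDF5 file if it already exists.
lines = EmnistLinesDataset(
    emnist=emnist,
    train=True,
    max_length=34,
    min_overlap=0.0,
    max_overlap=0.33,
    num_samples=1000,
)
lines._load_or_generate_data()

image, target = lines[0]  # (1, 28, 28 * 34) float tensor and a length-34 label tensor
print(lines)

The create_datasets helper at the bottom of emnist_lines_dataset.py runs the same steps for both the training and validation splits.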
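Finally, a sketch of the reworked weight-loading path: with network_args omitted, CharacterModel now falls back to Model.load_weights, which restores both the weights and the stored network_args from the saved *_weights.pt file, so the predictor only needs the network class. The MLP import path and the image filename below are assumptions based on the test module:

from text_recognizer.character_predictor import CharacterPredictor
from text_recognizer.networks import MLP  # assumed location, as used by the test via importlib

# No network_args: the architecture arguments come from the checkpointed weights file.
predictor = CharacterPredictor(network_fn=MLP)
prediction, confidence = predictor.predict("support/emnist/a.png")  # illustrative path
print(prediction, confidence)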