diff options
Diffstat (limited to 'src/text_recognizer')
19 files changed, 781 insertions, 48 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py new file mode 100644 index 0000000..69ef896 --- /dev/null +++ b/src/text_recognizer/character_predictor.py @@ -0,0 +1,26 @@ +"""CharacterPredictor class.""" + +from typing import Tuple, Union + +import numpy as np + +from text_recognizer.models import CharacterModel +from text_recognizer.util import read_image + + +class CharacterPredictor: + """Recognizes the character in handwritten character images.""" + + def __init__(self) -> None: + """Intializes the CharacterModel and load the pretrained weights.""" + self.model = CharacterModel() + self.model.load_weights() + self.model.eval() + + def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: + """Predict on a single images contianing a handwritten character.""" + if isinstance(image_or_filename, str): + image = read_image(image_or_filename, grayscale=True) + else: + image = image_or_filename + return self.model.predict_on_image(image) diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index cbaf1d9..929cfb9 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,2 +1,2 @@ """Dataset modules.""" -# from .emnist_dataset import fetch_dataloader +from .data_loader import fetch_data_loader diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py new file mode 100644 index 0000000..fd55934 --- /dev/null +++ b/src/text_recognizer/datasets/data_loader.py @@ -0,0 +1,15 @@ +"""Data loader collection.""" + +from typing import Dict + +from torch.utils.data import DataLoader + +from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader + + +def fetch_data_loader(data_loader_args: Dict) -> DataLoader: + """Fetches the specified PyTorch data loader.""" + if data_loader_args.pop("name").lower() == "emnist": + return fetch_emnist_data_loader(data_loader_args) + else: + raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 204faeb..f9c8ffa 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -1,72 +1,155 @@ """Fetches a PyTorch DataLoader with the EMNIST dataset.""" + +import json from pathlib import Path -from typing import Callable +from typing import Callable, Dict, List, Optional -import click from loguru import logger +import numpy as np +from PIL import Image from torch.utils.data import DataLoader from torchvision.datasets import EMNIST +from torchvision.transforms import Compose, ToTensor + + +DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" +ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) + + +def save_emnist_essentials(emnsit_dataset: EMNIST) -> None: + """Extract and saves EMNIST essentials.""" + labels = emnsit_dataset.classes + labels.sort() + mapping = [(i, str(label)) for i, label in enumerate(labels)] + essentials = { + "mapping": mapping, + "input_shape": tuple(emnsit_dataset[0][0].shape[:]), + } + logger.info("Saving emnist essentials...") + with open(ESSENTIALS_FILENAME, "w") as f: + json.dump(essentials, f) -@click.command() -@click.option("--split", "-s", default="byclass") -def download_emnist(split: str) -> None: +def download_emnist() -> None: """Download the EMNIST dataset via the PyTorch class.""" - data_dir = Path(__file__).resolve().parents[3] / "data" - logger.debug(f"Data directory is: {data_dir}") - EMNIST(root=data_dir, split=split, download=True) + logger.info(f"Data directory is: {DATA_DIRNAME}") + dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) + save_emnist_essentials(dataset) -def fetch_dataloader( - root: str, +def load_emnist_mapping() -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return dict(essentials["mapping"]) + + +def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(seed) + x = dataset.data + y = dataset.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_inds = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_inds.append(sampled_inds) + ind = np.concatenate(all_sampled_inds) + x_sampled = x[ind] + y_sampled = y[ind] + dataset.data = x_sampled + dataset.targets = y_sampled + + +def fetch_emnist_dataset( split: str, train: bool, - download: bool, - transform: Callable = None, - target_transform: Callable = None, + sample_to_balance: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, +) -> EMNIST: + """Fetch the EMNIST dataset.""" + if transform is None: + transform = Compose([Transpose(), ToTensor()]) + + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=train, + download=False, + transform=transform, + target_transform=target_transform, + ) + + if sample_to_balance and split == "byclass": + _sample_to_balance(dataset) + + return dataset + + +def fetch_emnist_data_loader( + splits: List[str], + sample_to_balance: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, batch_size: int = 128, shuffle: bool = False, num_workers: int = 0, cuda: bool = True, -) -> DataLoader: - """Down/load the EMNIST dataset and return a PyTorch DataLoader. +) -> Dict[DataLoader]: + """Fetches the EMNIST dataset and return a PyTorch DataLoader. Args: - root (str): Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt - exist. - split (str): The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. - This argument specifies which one to use. - train (bool): If True, creates dataset from training.pt, otherwise from test.pt. - download (bool): If true, downloads the dataset from the internet and puts it in root directory. If - dataset is already downloaded, it is not downloaded again. - transform (Callable): A function/transform that takes in an PIL image and returns a transformed version. - E.g, transforms.RandomCrop. - target_transform (Callable): A function/transform that takes in the target and transforms it. - batch_size (int): How many samples per batch to load (the default is 128). - shuffle (bool): Set to True to have the data reshuffled at every epoch (the default is False). - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be loaded in - the main process (default: 0). - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning them. + splits (List[str]): One or both of the dataset splits "train" and "val". + sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. + Defaults to False. + transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a + transformed version. E.g, transforms.RandomCrop. Defaults to None. + target_transform (Optional[Callable]): A function/transform that takes in the target and transforms + it. + Defaults to None. + batch_size (int): How many samples per batch to load. Defaults to 128. + shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be + loaded in the main process. Defaults to 0. + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning + them. Defaults to True. Returns: - DataLoader: A PyTorch DataLoader with emnist characters. + Dict: A dict containing PyTorch DataLoader(s) with emnist characters. """ - dataset = EMNIST( - root=root, - split=split, - train=train, - download=download, - transform=transform, - target_transform=target_transform, - ) + data_loaders = {} - data_loader = DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - pin_memory=cuda, - ) + for split in ["train", "val"]: + if split in splits: + + if split == "train": + train = True + else: + train = False + + dataset = fetch_emnist_dataset( + split, train, sample_to_balance, transform, target_transform + ) + + data_loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=cuda, + ) + + data_loaders[split] = data_loader - return data_loader + return data_loaders diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/src/text_recognizer/datasets/emnist_essentials.json new file mode 100644 index 0000000..2a0648a --- /dev/null +++ b/src/text_recognizer/datasets/emnist_essentials.json @@ -0,0 +1 @@ +{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]} diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index aa26de6..d265dcf 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1 +1,2 @@ """Model modules.""" +from .character_model import CharacterModel diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py new file mode 100644 index 0000000..736af7b --- /dev/null +++ b/src/text_recognizer/models/base.py @@ -0,0 +1,230 @@ +"""Abstract Model class for PyTorch neural networks.""" + +from abc import ABC, abstractmethod +from pathlib import Path +import shutil +from typing import Callable, Dict, Optional, Tuple + +from loguru import logger +import torch +from torch import nn +from torchsummary import summary + +from text_recognizer.dataset.data_loader import fetch_data_loader + +WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" + + +class Model(ABC): + """Abstract Model class with composition of different parts defining a PyTorch neural network.""" + + def __init__( + self, + network_fn: Callable, + network_args: Dict, + data_loader_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Base class, to be inherited by predictors for specific type of data. + + Args: + network_fn (Callable): The PyTorch network. + network_args (Dict): Arguments for the network. + data_loader_args (Optional[Dict]): Arguments for the data loader. + metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. + criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. + Defaults to None. + criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. + optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. + optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None. + lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. + lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to + None. + device (Optional[str]): Name of the device to train on. Defaults to None. + + """ + + # Fetch data loaders. + if data_loader_args is not None: + self._data_loaders = fetch_data_loader(**data_loader_args) + dataset_name = self._data_loaders.items()[0].dataset.__name__ + else: + dataset_name = "" + self._data_loaders = None + + self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" + + # Extract the input shape for the torchsummary. + self._input_shape = network_args.pop("input_shape") + + if metrics is not None: + self._metrics = metrics + + # Set the device. + if self.device is None: + self._device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" + ) + else: + self._device = device + + # Load network. + self._network = network_fn(**network_args) + + # To device. + self._network.to(self._device) + + # Set criterion. + self._criterion = None + if criterion is not None: + self._criterion = criterion(**criterion_args) + + # Set optimizer. + self._optimizer = None + if optimizer is not None: + self._optimizer = optimizer(self._network.parameters(), **optimizer_args) + + # Set learning rate scheduler. + self._lr_scheduler = None + if lr_scheduler is not None: + self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) + + @property + def input_shape(self) -> Tuple[int, ...]: + """The input shape.""" + return self._input_shape + + def eval(self) -> None: + """Sets the network to evaluation mode.""" + self._network.eval() + + def train(self) -> None: + """Sets the network to train mode.""" + self._network.train() + + @property + def device(self) -> str: + """Device where the weights are stored, i.e. cpu or cuda.""" + return self._device + + @property + def metrics(self) -> Optional[Dict]: + """Metrics.""" + return self._metrics + + @property + def criterion(self) -> Optional[Callable]: + """Criterion.""" + return self._criterion + + @property + def optimizer(self) -> Optional[Callable]: + """Optimizer.""" + return self._optimizer + + @property + def lr_scheduler(self) -> Optional[Callable]: + """Learning rate scheduler.""" + return self._lr_scheduler + + @property + def data_loaders(self) -> Optional[Dict]: + """Dataloaders.""" + return self._data_loaders + + @property + def network(self) -> nn.Module: + """Neural network.""" + return self._network + + @property + def weights_filename(self) -> str: + """Filepath to the network weights.""" + WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) + return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt") + + def summary(self) -> None: + """Prints a summary of the network architecture.""" + summary(self._network, self._input_shape, device=self.device) + + def _get_state(self) -> Dict: + """Get the state dict of the model.""" + state = {"model_state": self._network.state_dict()} + if self._optimizer is not None: + state["optimizer_state"] = self._optimizer.state_dict() + return state + + def load_checkpoint(self, path: Path) -> int: + """Load a previously saved checkpoint. + + Args: + path (Path): Path to the experiment with the checkpoint. + + Returns: + epoch (int): The last epoch when the checkpoint was created. + + """ + if not path.exists(): + logger.debug("File does not exist {str(path)}") + + checkpoint = torch.load(str(path)) + self._network.load_state_dict(checkpoint["model_state"]) + + if self._optimizer is not None: + self._optimizer.load_state_dict(checkpoint["optimizer_state"]) + + epoch = checkpoint["epoch"] + + return epoch + + def save_checkpoint( + self, path: Path, is_best: bool, epoch: int, val_metric: str + ) -> None: + """Saves a checkpoint of the model. + + Args: + path (Path): Path to the experiment folder. + is_best (bool): If it is the currently best model. + epoch (int): The epoch of the checkpoint. + val_metric (str): Validation metric. + + """ + state = self._get_state_dict() + state["is_best"] = is_best + state["epoch"] = epoch + + path.mkdir(parents=True, exist_ok=True) + + logger.debug("Saving checkpoint...") + filepath = str(path / "last.pt") + torch.save(state, filepath) + + if is_best: + logger.debug( + f"Found a new best {val_metric}. Saving best checkpoint and weights." + ) + self.save_weights() + shutil.copyfile(filepath, str(path / "best.pt")) + + def load_weights(self) -> None: + """Load the network weights.""" + logger.debug("Loading network weights.") + weights = torch.load(self.weights_filename)["model_state"] + self._network.load_state_dict(weights) + + def save_weights(self) -> None: + """Save the network weights.""" + logger.debug("Saving network weights.") + torch.save({"model_state": self._network.state_dict()}, self.weights_filename) + + @abstractmethod + def mapping(self) -> Dict: + """Mapping from network output to class.""" + ... diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py new file mode 100644 index 0000000..1570344 --- /dev/null +++ b/src/text_recognizer/models/character_model.py @@ -0,0 +1,71 @@ +"""Defines the CharacterModel class.""" +from typing import Callable, Dict, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.emnist_dataset import load_emnist_mapping +from text_recognizer.models.base import Model +from text_recognizer.networks.mlp import mlp + + +class CharacterModel(Model): + """Model for predicting characters from images.""" + + def __init__( + self, + network_fn: Callable, + network_args: Dict, + data_loader_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + data_loader_args, + network_args, + metrics, + criterion, + optimizer, + device, + ) + self.emnist_mapping = self.mapping() + self.eval() + + def mapping(self) -> Dict: + """Mapping between integers and classes.""" + mapping = load_emnist_mapping() + return mapping + + def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]: + """Character prediction on an image. + + Args: + image (np.ndarray): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. + + """ + if image.dtype == np.uint8: + image = (image / 255).astype(np.float32) + + # Conver to Pytorch Tensor. + image = ToTensor(image) + + prediction = self.network(image) + index = torch.argmax(prediction, dim=1) + confidence_of_prediction = prediction[index] + predicted_character = self.emnist_mapping[index] + + return predicted_character, confidence_of_prediction diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/util.py new file mode 100644 index 0000000..905fe7b --- /dev/null +++ b/src/text_recognizer/models/util.py @@ -0,0 +1,19 @@ +"""Utility functions for models.""" + +import torch + + +def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: + """Short summary. + + Args: + outputs (torch.Tensor): The output from the network. + labels (torch.Tensor): Ground truth labels. + + Returns: + float: The accuracy for the batch. + + """ + _, predicted = torch.max(outputs.data, dim=1) + acc = (predicted == labels).sum().item() / labels.shape[0] + return acc diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py new file mode 100644 index 0000000..71d247f --- /dev/null +++ b/src/text_recognizer/networks/lenet.py @@ -0,0 +1,93 @@ +"""Defines the LeNet network.""" +from typing import Callable, Optional, Tuple + +import torch +from torch import nn + + +class Flatten(nn.Module): + """Flattens a tensor.""" + + def forward(self, x: int) -> torch.Tensor: + """Flattens a tensor for input to a nn.Linear layer.""" + return torch.flatten(x, start_dim=1) + + +class LeNet(nn.Module): + """LeNet network.""" + + def __init__( + self, + channels: Tuple[int, ...], + kernel_sizes: Tuple[int, ...], + hidden_size: Tuple[int, ...], + dropout_rate: float, + output_size: int, + activation_fn: Optional[Callable] = None, + ) -> None: + """The LeNet network. + + Args: + channels (Tuple[int, ...]): Channels in the convolutional layers. + kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. + hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. + dropout_rate (float): The dropout rate. + output_size (int): Number of classes. + activation_fn (Optional[Callable]): The non-linear activation function. Defaults to + nn.ReLU(inplace). + + """ + super().__init__() + + if activation_fn is None: + activation_fn = nn.ReLU(inplace=True) + + self.layers = [ + nn.Conv2d( + in_channels=channels[0], + out_channels=channels[1], + kernel_size=kernel_sizes[0], + ), + activation_fn, + nn.Conv2d( + in_channels=channels[1], + out_channels=channels[2], + kernel_size=kernel_sizes[1], + ), + activation_fn, + nn.MaxPool2d(kernel_sizes[2]), + nn.Dropout(p=dropout_rate), + Flatten(), + nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), + activation_fn, + nn.Dropout(p=dropout_rate), + nn.Linear(in_features=hidden_size[1], out_features=output_size), + ] + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward.""" + return self.layers(x) + + +# def test(): +# x = torch.randn([1, 1, 28, 28]) +# channels = [1, 32, 64] +# kernel_sizes = [3, 3, 2] +# hidden_size = [9216, 128] +# output_size = 10 +# dropout_rate = 0.2 +# activation_fn = nn.ReLU() +# net = LeNet( +# channels=channels, +# kernel_sizes=kernel_sizes, +# dropout_rate=dropout_rate, +# hidden_size=hidden_size, +# output_size=output_size, +# activation_fn=activation_fn, +# ) +# from torchsummary import summary +# +# summary(net, (1, 28, 28), device="cpu") +# out = net(x) diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py new file mode 100644 index 0000000..2a41790 --- /dev/null +++ b/src/text_recognizer/networks/mlp.py @@ -0,0 +1,81 @@ +"""Defines the MLP network.""" +from typing import Callable, Optional + +import torch +from torch import nn + + +class MLP(nn.Module): + """Multi layered perceptron network.""" + + def __init__( + self, + input_size: int, + output_size: int, + hidden_size: int, + num_layers: int, + dropout_rate: float, + activation_fn: Optional[Callable] = None, + ) -> None: + """Initialization of the MLP network. + + Args: + input_size (int): The input shape of the network. + output_size (int): Number of classes in the dataset. + hidden_size (int): The number of `neurons` in each hidden layer. + num_layers (int): The number of hidden layers. + dropout_rate (float): The dropout rate at each layer. + activation_fn (Optional[Callable]): The activation function in the hidden layers, (default: + nn.ReLU()). + + """ + super().__init__() + + if activation_fn is None: + activation_fn = nn.ReLU(inplace=True) + + self.layers = [ + nn.Linear(in_features=input_size, out_features=hidden_size), + activation_fn, + ] + + for _ in range(num_layers): + self.layers += [ + nn.Linear(in_features=hidden_size, out_features=hidden_size), + activation_fn, + ] + + if dropout_rate: + self.layers.append(nn.Dropout(p=dropout_rate)) + + self.layers.append(nn.Linear(in_features=hidden_size, out_features=output_size)) + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward.""" + x = torch.flatten(x, start_dim=1) + return self.layers(x) + + +# def test(): +# x = torch.randn([1, 28, 28]) +# input_size = torch.flatten(x).shape[0] +# output_size = 10 +# hidden_size = 128 +# num_layers = 5 +# dropout_rate = 0.25 +# activation_fn = nn.GELU() +# net = MLP( +# input_size=input_size, +# output_size=output_size, +# hidden_size=hidden_size, +# num_layers=num_layers, +# dropout_rate=dropout_rate, +# activation_fn=activation_fn, +# ) +# from torchsummary import summary +# +# summary(net, (1, 28, 28), device="cpu") +# +# out = net(x) diff --git a/src/text_recognizer/tests/__init__.py b/src/text_recognizer/tests/__init__.py new file mode 100644 index 0000000..18ff212 --- /dev/null +++ b/src/text_recognizer/tests/__init__.py @@ -0,0 +1 @@ +"""Test modules for the text text recognizer.""" diff --git a/src/text_recognizer/tests/support/__init__.py b/src/text_recognizer/tests/support/__init__.py new file mode 100644 index 0000000..a265ede --- /dev/null +++ b/src/text_recognizer/tests/support/__init__.py @@ -0,0 +1,2 @@ +"""Support file modules.""" +from .create_emnist_support_files import create_emnist_support_files diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py new file mode 100644 index 0000000..5dd1a81 --- /dev/null +++ b/src/text_recognizer/tests/support/create_emnist_support_files.py @@ -0,0 +1,33 @@ +"""Module for creating EMNIST test support files.""" +from pathlib import Path +import shutil + +from text_recognizer.datasets.emnist_dataset import ( + fetch_emnist_dataset, + load_emnist_mapping, +) +from text_recognizer.util import write_image + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" + + +def create_emnist_support_files() -> None: + """Create support images for test of CharacterPredictor class.""" + shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) + SUPPORT_DIRNAME.mkdir() + + dataset = fetch_emnist_dataset(split="byclass", train=False) + mapping = load_emnist_mapping() + + for index in [5, 7, 9]: + image, label = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) + image = image.numpy() + label = mapping[int(label)] + print(index, label) + write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) + + +if __name__ == "__main__": + create_emnist_support_files() diff --git a/src/text_recognizer/tests/support/emnist/8.png b/src/text_recognizer/tests/support/emnist/8.png Binary files differnew file mode 100644 index 0000000..faa29aa --- /dev/null +++ b/src/text_recognizer/tests/support/emnist/8.png diff --git a/src/text_recognizer/tests/support/emnist/U.png b/src/text_recognizer/tests/support/emnist/U.png Binary files differnew file mode 100644 index 0000000..304eaec --- /dev/null +++ b/src/text_recognizer/tests/support/emnist/U.png diff --git a/src/text_recognizer/tests/support/emnist/e.png b/src/text_recognizer/tests/support/emnist/e.png Binary files differnew file mode 100644 index 0000000..a03ecd4 --- /dev/null +++ b/src/text_recognizer/tests/support/emnist/e.png diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py new file mode 100644 index 0000000..7c094ef --- /dev/null +++ b/src/text_recognizer/tests/test_character_predictor.py @@ -0,0 +1,26 @@ +"""Test for CharacterPredictor class.""" +import os +from pathlib import Path +import unittest + +from text_recognizer.character_predictor import CharacterPredictor + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist" + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestCharacterPredictor(unittest.TestCase): + """Tests for the CharacterPredictor class.""" + + def test_filename(self) -> None: + """Test that CharacterPredictor correctly predicts on a single image, for serveral test images.""" + predictor = CharacterPredictor() + + for filename in SUPPORT_DIRNAME.glob("*.png"): + pred, conf = predictor.predict(str(filename)) + print( + f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}" + ) + self.assertEqual(pred, filename.stem) + self.assertGreater(conf, 0.7) diff --git a/src/text_recognizer/util.py b/src/text_recognizer/util.py new file mode 100644 index 0000000..52fa1e4 --- /dev/null +++ b/src/text_recognizer/util.py @@ -0,0 +1,51 @@ +"""Utility functions for text_recognizer module.""" +import os +from pathlib import Path +from typing import Union +from urllib.request import urlopen + +import cv2 +import numpy as np + + +def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray: + """Read image_uri.""" + + def read_image_from_filename(image_filename: str, imread_flag: int) -> np.ndarray: + return cv2.imread(str(image_filename), imread_flag) + + def read_image_from_url(image_url: str, imread_flag: int) -> np.ndarray: + if image_url.lower().startswith("http"): + url_response = urlopen(str(image_url)) + image_array = np.array(bytearray(url_response.read()), dtype=np.uint8) + return cv2.imdecode(image_array, imread_flag) + else: + raise ValueError( + "Url does not start with http, therfore not safe to open..." + ) from None + + imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR + local_file = os.path.exsits(image_uri) + try: + image = None + if local_file: + image = read_image_from_filename(image_uri, imread_flag) + else: + image = read_image_from_url(image_uri, imread_flag) + assert image is not None + except Exception as e: + raise ValueError(f"Could not load image at {image_uri}: {e}") + return image + + +def rescale_image(image: np.ndarray) -> np.ndarray: + """Rescale image from [0, 1] to [0, 255].""" + if image.max() <= 1.0: + image = 255 * (image - image.min()) / (image.max() - image.min()) + return image + + +def write_image(image: np.ndarray, filename: Union[Path, str]) -> None: + """Write image to file.""" + image = rescale_image(image) + cv2.imwrite(str(filename), image) |