summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-06-23 22:39:54 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-06-23 22:39:54 +0200
commit7c4de6d88664d2ea1b084f316a11896dde3e1150 (patch)
treecbde7e64c8064e9b523dfb65cd6c487d061ec805 /src/text_recognizer
parenta7a9ce59fc37317dd74c3a52caf7c4e68e570ee8 (diff)
latest
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/character_predictor.py26
-rw-r--r--src/text_recognizer/datasets/__init__.py2
-rw-r--r--src/text_recognizer/datasets/data_loader.py15
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py177
-rw-r--r--src/text_recognizer/datasets/emnist_essentials.json1
-rw-r--r--src/text_recognizer/models/__init__.py1
-rw-r--r--src/text_recognizer/models/base.py230
-rw-r--r--src/text_recognizer/models/character_model.py71
-rw-r--r--src/text_recognizer/models/util.py19
-rw-r--r--src/text_recognizer/networks/lenet.py93
-rw-r--r--src/text_recognizer/networks/mlp.py81
-rw-r--r--src/text_recognizer/tests/__init__.py1
-rw-r--r--src/text_recognizer/tests/support/__init__.py2
-rw-r--r--src/text_recognizer/tests/support/create_emnist_support_files.py33
-rw-r--r--src/text_recognizer/tests/support/emnist/8.pngbin0 -> 498 bytes
-rw-r--r--src/text_recognizer/tests/support/emnist/U.pngbin0 -> 524 bytes
-rw-r--r--src/text_recognizer/tests/support/emnist/e.pngbin0 -> 563 bytes
-rw-r--r--src/text_recognizer/tests/test_character_predictor.py26
-rw-r--r--src/text_recognizer/util.py51
19 files changed, 781 insertions, 48 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py
new file mode 100644
index 0000000..69ef896
--- /dev/null
+++ b/src/text_recognizer/character_predictor.py
@@ -0,0 +1,26 @@
+"""CharacterPredictor class."""
+
+from typing import Tuple, Union
+
+import numpy as np
+
+from text_recognizer.models import CharacterModel
+from text_recognizer.util import read_image
+
+
+class CharacterPredictor:
+ """Recognizes the character in handwritten character images."""
+
+ def __init__(self) -> None:
+ """Initializes the CharacterModel and loads the pretrained weights."""
+ self.model = CharacterModel()
+ self.model.load_weights()
+ self.model.eval()
+
+ def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
+ """Predict on a single image containing a handwritten character."""
+ if isinstance(image_or_filename, str):
+ image = read_image(image_or_filename, grayscale=True)
+ else:
+ image = image_or_filename
+ return self.model.predict_on_image(image)
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index cbaf1d9..929cfb9 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,2 +1,2 @@
"""Dataset modules."""
-# from .emnist_dataset import fetch_dataloader
+from .data_loader import fetch_data_loader
diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py
new file mode 100644
index 0000000..fd55934
--- /dev/null
+++ b/src/text_recognizer/datasets/data_loader.py
@@ -0,0 +1,15 @@
+"""Data loader collection."""
+
+from typing import Dict
+
+from torch.utils.data import DataLoader
+
+from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader
+
+
+def fetch_data_loader(data_loader_args: Dict) -> DataLoader:
+ """Fetches the specified PyTorch data loader."""
+ if data_loader_args.pop("name").lower() == "emnist":
+ return fetch_emnist_data_loader(**data_loader_args)
+ else:
+ raise NotImplementedError
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 204faeb..f9c8ffa 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -1,72 +1,155 @@
"""Fetches a PyTorch DataLoader with the EMNIST dataset."""
+
+import json
from pathlib import Path
-from typing import Callable
+from typing import Callable, Dict, List, Optional
-import click
from loguru import logger
+import numpy as np
+from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import EMNIST
+from torchvision.transforms import Compose, ToTensor
+
+
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+
+
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
+
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)
+
+
+def save_emnist_essentials(emnsit_dataset: EMNIST) -> None:
+ """Extracts and saves EMNIST essentials."""
+ labels = emnsit_dataset.classes
+ labels.sort()
+ mapping = [(i, str(label)) for i, label in enumerate(labels)]
+ essentials = {
+ "mapping": mapping,
+ "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
+ }
+ logger.info("Saving emnist essentials...")
+ with open(ESSENTIALS_FILENAME, "w") as f:
+ json.dump(essentials, f)
-@click.command()
-@click.option("--split", "-s", default="byclass")
-def download_emnist(split: str) -> None:
+def download_emnist() -> None:
"""Download the EMNIST dataset via the PyTorch class."""
- data_dir = Path(__file__).resolve().parents[3] / "data"
- logger.debug(f"Data directory is: {data_dir}")
- EMNIST(root=data_dir, split=split, download=True)
+ logger.info(f"Data directory is: {DATA_DIRNAME}")
+ dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+ save_emnist_essentials(dataset)
-def fetch_dataloader(
- root: str,
+def load_emnist_mapping() -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return dict(essentials["mapping"])
+
+
+def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None:
+ """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+ np.random.seed(seed)
+ x = dataset.data
+ y = dataset.targets
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_inds = []
+ for label in np.unique(y.flatten()):
+ inds = np.where(y == label)[0]
+ sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
+ all_sampled_inds.append(sampled_inds)
+ ind = np.concatenate(all_sampled_inds)
+ x_sampled = x[ind]
+ y_sampled = y[ind]
+ dataset.data = x_sampled
+ dataset.targets = y_sampled
+
+
+def fetch_emnist_dataset(
split: str,
train: bool,
- download: bool,
- transform: Callable = None,
- target_transform: Callable = None,
+ sample_to_balance: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+) -> EMNIST:
+ """Fetch the EMNIST dataset."""
+ if transform is None:
+ transform = Compose([Transpose(), ToTensor()])
+
+ dataset = EMNIST(
+ root=DATA_DIRNAME,
+ split="byclass",
+ train=train,
+ download=False,
+ transform=transform,
+ target_transform=target_transform,
+ )
+
+ if sample_to_balance and split == "byclass":
+ _sample_to_balance(dataset)
+
+ return dataset
+
+
+def fetch_emnist_data_loader(
+ splits: List[str],
+ sample_to_balance: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
batch_size: int = 128,
shuffle: bool = False,
num_workers: int = 0,
cuda: bool = True,
-) -> DataLoader:
- """Down/load the EMNIST dataset and return a PyTorch DataLoader.
+) -> Dict[str, DataLoader]:
+ """Fetches the EMNIST dataset and returns a dict of PyTorch DataLoaders.
Args:
- root (str): Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt
- exist.
- split (str): The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist.
- This argument specifies which one to use.
- train (bool): If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool): If true, downloads the dataset from the internet and puts it in root directory. If
- dataset is already downloaded, it is not downloaded again.
- transform (Callable): A function/transform that takes in an PIL image and returns a transformed version.
- E.g, transforms.RandomCrop.
- target_transform (Callable): A function/transform that takes in the target and transforms it.
- batch_size (int): How many samples per batch to load (the default is 128).
- shuffle (bool): Set to True to have the data reshuffled at every epoch (the default is False).
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be loaded in
- the main process (default: 0).
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning them.
+ splits (List[str]): One or both of the dataset splits "train" and "val".
+ sample_to_balance (bool): If true, resamples the unbalanced dataset if the split "byclass" is selected.
+ Defaults to False.
+ transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
+ transformed version. E.g, transforms.RandomCrop. Defaults to None.
+ target_transform (Optional[Callable]): A function/transform that takes in the target and transforms
+ it.
+ Defaults to None.
+ batch_size (int): How many samples per batch to load. Defaults to 128.
+ shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
+ num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
+ loaded in the main process. Defaults to 0.
+ cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
+ them. Defaults to True.
Returns:
- DataLoader: A PyTorch DataLoader with emnist characters.
+ Dict: A dict containing PyTorch DataLoader(s) with emnist characters.
"""
- dataset = EMNIST(
- root=root,
- split=split,
- train=train,
- download=download,
- transform=transform,
- target_transform=target_transform,
- )
+ data_loaders = {}
- data_loader = DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- pin_memory=cuda,
- )
+ for split in ["train", "val"]:
+ if split in splits:
+
+ if split == "train":
+ train = True
+ else:
+ train = False
+
+ dataset = fetch_emnist_dataset(
+ split, train, sample_to_balance, transform, target_transform
+ )
+
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
+ pin_memory=cuda,
+ )
+
+ data_loaders[split] = data_loader
- return data_loader
+ return data_loaders
diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/src/text_recognizer/datasets/emnist_essentials.json
new file mode 100644
index 0000000..2a0648a
--- /dev/null
+++ b/src/text_recognizer/datasets/emnist_essentials.json
@@ -0,0 +1 @@
+{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]}
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index aa26de6..d265dcf 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1 +1,2 @@
"""Model modules."""
+from .character_model import CharacterModel
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
new file mode 100644
index 0000000..736af7b
--- /dev/null
+++ b/src/text_recognizer/models/base.py
@@ -0,0 +1,230 @@
+"""Abstract Model class for PyTorch neural networks."""
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+import shutil
+from typing import Callable, Dict, Optional, Tuple
+
+from loguru import logger
+import torch
+from torch import nn
+from torchsummary import summary
+
+from text_recognizer.datasets.data_loader import fetch_data_loader
+
+WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
+
+
+class Model(ABC):
+ """Abstract Model class with composition of different parts defining a PyTorch neural network."""
+
+ def __init__(
+ self,
+ network_fn: Callable,
+ network_args: Dict,
+ data_loader_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Base class, to be inherited by predictors for specific type of data.
+
+ Args:
+ network_fn (Callable): The PyTorch network.
+ network_args (Dict): Arguments for the network.
+ data_loader_args (Optional[Dict]): Arguments for the data loader.
+ metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
+ Defaults to None.
+ criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
+ optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
+ optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
+ lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
+ lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
+ None.
+ device (Optional[str]): Name of the device to train on. Defaults to None.
+
+ """
+
+ # Fetch data loaders.
+ if data_loader_args is not None:
+ self._data_loaders = fetch_data_loader(data_loader_args)
+ dataset_name = type(next(iter(self._data_loaders.values())).dataset).__name__
+ else:
+ dataset_name = ""
+ self._data_loaders = None
+
+ self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+
+ # Extract the input shape for the torchsummary.
+ self._input_shape = network_args.pop("input_shape")
+
+ # Always define the attribute so the `metrics` property works without metrics.
+ self._metrics = metrics
+
+ # Set the device.
+ if device is None:
+ self._device = torch.device(
+ "cuda:0" if torch.cuda.is_available() else "cpu"
+ )
+ else:
+ self._device = device
+
+ # Load network.
+ self._network = network_fn(**network_args)
+
+ # To device.
+ self._network.to(self._device)
+
+ # Set criterion.
+ self._criterion = None
+ if criterion is not None:
+ self._criterion = criterion(**(criterion_args or {}))
+
+ # Set optimizer.
+ self._optimizer = None
+ if optimizer is not None:
+ self._optimizer = optimizer(self._network.parameters(), **(optimizer_args or {}))
+
+ # Set learning rate scheduler.
+ self._lr_scheduler = None
+ if lr_scheduler is not None:
+ self._lr_scheduler = lr_scheduler(self._optimizer, **(lr_scheduler_args or {}))
+
+ @property
+ def input_shape(self) -> Tuple[int, ...]:
+ """The input shape."""
+ return self._input_shape
+
+ def eval(self) -> None:
+ """Sets the network to evaluation mode."""
+ self._network.eval()
+
+ def train(self) -> None:
+ """Sets the network to train mode."""
+ self._network.train()
+
+ @property
+ def device(self) -> str:
+ """Device where the weights are stored, i.e. cpu or cuda."""
+ return self._device
+
+ @property
+ def metrics(self) -> Optional[Dict]:
+ """Metrics."""
+ return self._metrics
+
+ @property
+ def criterion(self) -> Optional[Callable]:
+ """Criterion."""
+ return self._criterion
+
+ @property
+ def optimizer(self) -> Optional[Callable]:
+ """Optimizer."""
+ return self._optimizer
+
+ @property
+ def lr_scheduler(self) -> Optional[Callable]:
+ """Learning rate scheduler."""
+ return self._lr_scheduler
+
+ @property
+ def data_loaders(self) -> Optional[Dict]:
+ """Dataloaders."""
+ return self._data_loaders
+
+ @property
+ def network(self) -> nn.Module:
+ """Neural network."""
+ return self._network
+
+ @property
+ def weights_filename(self) -> str:
+ """Filepath to the network weights."""
+ WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt")
+
+ def summary(self) -> None:
+ """Prints a summary of the network architecture."""
+ summary(self._network, self._input_shape, device=self.device)
+
+ def _get_state(self) -> Dict:
+ """Get the state dict of the model."""
+ state = {"model_state": self._network.state_dict()}
+ if self._optimizer is not None:
+ state["optimizer_state"] = self._optimizer.state_dict()
+ return state
+
+ def load_checkpoint(self, path: Path) -> int:
+ """Load a previously saved checkpoint.
+
+ Args:
+ path (Path): Path to the experiment with the checkpoint.
+
+ Returns:
+ epoch (int): The last epoch when the checkpoint was created.
+
+ """
+ if not path.exists():
+ logger.debug(f"File does not exist {str(path)}")
+
+ checkpoint = torch.load(str(path))
+ self._network.load_state_dict(checkpoint["model_state"])
+
+ if self._optimizer is not None:
+ self._optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+ epoch = checkpoint["epoch"]
+
+ return epoch
+
+ def save_checkpoint(
+ self, path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
+ """Saves a checkpoint of the model.
+
+ Args:
+ path (Path): Path to the experiment folder.
+ is_best (bool): If it is the currently best model.
+ epoch (int): The epoch of the checkpoint.
+ val_metric (str): Validation metric.
+
+ """
+ state = self._get_state()
+ state["is_best"] = is_best
+ state["epoch"] = epoch
+
+ path.mkdir(parents=True, exist_ok=True)
+
+ logger.debug("Saving checkpoint...")
+ filepath = str(path / "last.pt")
+ torch.save(state, filepath)
+
+ if is_best:
+ logger.debug(
+ f"Found a new best {val_metric}. Saving best checkpoint and weights."
+ )
+ self.save_weights()
+ shutil.copyfile(filepath, str(path / "best.pt"))
+
+ def load_weights(self) -> None:
+ """Load the network weights."""
+ logger.debug("Loading network weights.")
+ weights = torch.load(self.weights_filename)["model_state"]
+ self._network.load_state_dict(weights)
+
+ def save_weights(self) -> None:
+ """Save the network weights."""
+ logger.debug("Saving network weights.")
+ torch.save({"model_state": self._network.state_dict()}, self.weights_filename)
+
+ @abstractmethod
+ def mapping(self) -> Dict:
+ """Mapping from network output to class."""
+ ...
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
new file mode 100644
index 0000000..1570344
--- /dev/null
+++ b/src/text_recognizer/models/character_model.py
@@ -0,0 +1,71 @@
+"""Defines the CharacterModel class."""
+from typing import Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
+from text_recognizer.models.base import Model
+from text_recognizer.networks.mlp import MLP
+
+
+class CharacterModel(Model):
+ """Model for predicting characters from images."""
+
+ def __init__(
+ self,
+ network_fn: Callable,
+ network_args: Dict,
+ data_loader_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ network_args,
+ data_loader_args,
+ metrics,
+ criterion, criterion_args,
+ optimizer, optimizer_args,
+ lr_scheduler, lr_scheduler_args, device,
+ )
+ self.emnist_mapping = self.mapping()
+ self.eval()
+
+ def mapping(self) -> Dict:
+ """Mapping between integers and classes."""
+ mapping = load_emnist_mapping()
+ return mapping
+
+ def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
+ """Character prediction on an image.
+
+ Args:
+ image (np.ndarray): An image containing a character.
+
+ Returns:
+ Tuple[str, float]: The predicted character and the confidence in the prediction.
+
+ """
+ if image.dtype == np.uint8:
+ image = (image / 255).astype(np.float32)
+
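+ # Convert to a PyTorch tensor and add the batch dimension the network expects.
+ image = ToTensor()(image).unsqueeze(0).to(self.device)
+
+ prediction = torch.softmax(self.network(image), dim=1)
+ index = int(torch.argmax(prediction, dim=1))
+ confidence_of_prediction = float(prediction[0, index])
+ predicted_character = self.emnist_mapping[index]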
+
+ return predicted_character, confidence_of_prediction
diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/util.py
new file mode 100644
index 0000000..905fe7b
--- /dev/null
+++ b/src/text_recognizer/models/util.py
@@ -0,0 +1,19 @@
+"""Utility functions for models."""
+
+import torch
+
+
+def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
+ """Computes the accuracy of a batch of predictions.
+
+ Args:
+ outputs (torch.Tensor): The output from the network.
+ labels (torch.Tensor): Ground truth labels.
+
+ Returns:
+ float: The accuracy for the batch.
+
+ """
+ _, predicted = torch.max(outputs.data, dim=1)
+ acc = (predicted == labels).sum().item() / labels.shape[0]
+ return acc
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
new file mode 100644
index 0000000..71d247f
--- /dev/null
+++ b/src/text_recognizer/networks/lenet.py
@@ -0,0 +1,93 @@
+"""Defines the LeNet network."""
+from typing import Callable, Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class Flatten(nn.Module):
+ """Flattens a tensor."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Flattens a tensor for input to a nn.Linear layer."""
+ return torch.flatten(x, start_dim=1)
+
+
+class LeNet(nn.Module):
+ """LeNet network."""
+
+ def __init__(
+ self,
+ channels: Tuple[int, ...],
+ kernel_sizes: Tuple[int, ...],
+ hidden_size: Tuple[int, ...],
+ dropout_rate: float,
+ output_size: int,
+ activation_fn: Optional[Callable] = None,
+ ) -> None:
+ """The LeNet network.
+
+ Args:
+ channels (Tuple[int, ...]): Channels in the convolutional layers.
+ kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers.
+ hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
+ dropout_rate (float): The dropout rate.
+ output_size (int): Number of classes.
+ activation_fn (Optional[Callable]): The non-linear activation function. Defaults to
+ nn.ReLU(inplace).
+
+ """
+ super().__init__()
+
+ if activation_fn is None:
+ activation_fn = nn.ReLU(inplace=True)
+
+ self.layers = [
+ nn.Conv2d(
+ in_channels=channels[0],
+ out_channels=channels[1],
+ kernel_size=kernel_sizes[0],
+ ),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=channels[1],
+ out_channels=channels[2],
+ kernel_size=kernel_sizes[1],
+ ),
+ activation_fn,
+ nn.MaxPool2d(kernel_sizes[2]),
+ nn.Dropout(p=dropout_rate),
+ Flatten(),
+ nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
+ activation_fn,
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(in_features=hidden_size[1], out_features=output_size),
+ ]
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward."""
+ return self.layers(x)
+
+
+# def test():
+# x = torch.randn([1, 1, 28, 28])
+# channels = [1, 32, 64]
+# kernel_sizes = [3, 3, 2]
+# hidden_size = [9216, 128]
+# output_size = 10
+# dropout_rate = 0.2
+# activation_fn = nn.ReLU()
+# net = LeNet(
+# channels=channels,
+# kernel_sizes=kernel_sizes,
+# dropout_rate=dropout_rate,
+# hidden_size=hidden_size,
+# output_size=output_size,
+# activation_fn=activation_fn,
+# )
+# from torchsummary import summary
+#
+# summary(net, (1, 28, 28), device="cpu")
+# out = net(x)
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
new file mode 100644
index 0000000..2a41790
--- /dev/null
+++ b/src/text_recognizer/networks/mlp.py
@@ -0,0 +1,81 @@
+"""Defines the MLP network."""
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+
+
+class MLP(nn.Module):
+ """Multi layered perceptron network."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ hidden_size: int,
+ num_layers: int,
+ dropout_rate: float,
+ activation_fn: Optional[Callable] = None,
+ ) -> None:
+ """Initialization of the MLP network.
+
+ Args:
+ input_size (int): The input shape of the network.
+ output_size (int): Number of classes in the dataset.
+ hidden_size (int): The number of `neurons` in each hidden layer.
+ num_layers (int): The number of hidden layers.
+ dropout_rate (float): The dropout rate at each layer.
+ activation_fn (Optional[Callable]): The activation function in the hidden layers, (default:
+ nn.ReLU()).
+
+ """
+ super().__init__()
+
+ if activation_fn is None:
+ activation_fn = nn.ReLU(inplace=True)
+
+ self.layers = [
+ nn.Linear(in_features=input_size, out_features=hidden_size),
+ activation_fn,
+ ]
+
+ for _ in range(num_layers):
+ self.layers += [
+ nn.Linear(in_features=hidden_size, out_features=hidden_size),
+ activation_fn,
+ ]
+
+ if dropout_rate:
+ self.layers.append(nn.Dropout(p=dropout_rate))
+
+ self.layers.append(nn.Linear(in_features=hidden_size, out_features=output_size))
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward."""
+ x = torch.flatten(x, start_dim=1)
+ return self.layers(x)
+
+
+# def test():
+# x = torch.randn([1, 28, 28])
+# input_size = torch.flatten(x).shape[0]
+# output_size = 10
+# hidden_size = 128
+# num_layers = 5
+# dropout_rate = 0.25
+# activation_fn = nn.GELU()
+# net = MLP(
+# input_size=input_size,
+# output_size=output_size,
+# hidden_size=hidden_size,
+# num_layers=num_layers,
+# dropout_rate=dropout_rate,
+# activation_fn=activation_fn,
+# )
+# from torchsummary import summary
+#
+# summary(net, (1, 28, 28), device="cpu")
+#
+# out = net(x)
diff --git a/src/text_recognizer/tests/__init__.py b/src/text_recognizer/tests/__init__.py
new file mode 100644
index 0000000..18ff212
--- /dev/null
+++ b/src/text_recognizer/tests/__init__.py
@@ -0,0 +1 @@
+"""Test modules for the text recognizer."""
diff --git a/src/text_recognizer/tests/support/__init__.py b/src/text_recognizer/tests/support/__init__.py
new file mode 100644
index 0000000..a265ede
--- /dev/null
+++ b/src/text_recognizer/tests/support/__init__.py
@@ -0,0 +1,2 @@
+"""Support file modules."""
+from .create_emnist_support_files import create_emnist_support_files
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py
new file mode 100644
index 0000000..5dd1a81
--- /dev/null
+++ b/src/text_recognizer/tests/support/create_emnist_support_files.py
@@ -0,0 +1,33 @@
+"""Module for creating EMNIST test support files."""
+from pathlib import Path
+import shutil
+
+from text_recognizer.datasets.emnist_dataset import (
+ fetch_emnist_dataset,
+ load_emnist_mapping,
+)
+from text_recognizer.util import write_image
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
+
+
+def create_emnist_support_files() -> None:
+ """Create support images for test of CharacterPredictor class."""
+ shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+ SUPPORT_DIRNAME.mkdir()
+
+ dataset = fetch_emnist_dataset(split="byclass", train=False)
+ mapping = load_emnist_mapping()
+
+ for index in [5, 7, 9]:
+ image, label = dataset[index]
+ if len(image.shape) == 3:
+ image = image.squeeze(0)
+ image = image.numpy()
+ label = mapping[int(label)]
+ print(index, label)
+ write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+if __name__ == "__main__":
+ create_emnist_support_files()
diff --git a/src/text_recognizer/tests/support/emnist/8.png b/src/text_recognizer/tests/support/emnist/8.png
new file mode 100644
index 0000000..faa29aa
--- /dev/null
+++ b/src/text_recognizer/tests/support/emnist/8.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist/U.png b/src/text_recognizer/tests/support/emnist/U.png
new file mode 100644
index 0000000..304eaec
--- /dev/null
+++ b/src/text_recognizer/tests/support/emnist/U.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist/e.png b/src/text_recognizer/tests/support/emnist/e.png
new file mode 100644
index 0000000..a03ecd4
--- /dev/null
+++ b/src/text_recognizer/tests/support/emnist/e.png
Binary files differ
diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py
new file mode 100644
index 0000000..7c094ef
--- /dev/null
+++ b/src/text_recognizer/tests/test_character_predictor.py
@@ -0,0 +1,26 @@
+"""Test for CharacterPredictor class."""
+import os
+from pathlib import Path
+import unittest
+
+from text_recognizer.character_predictor import CharacterPredictor
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestCharacterPredictor(unittest.TestCase):
+ """Tests for the CharacterPredictor class."""
+
+ def test_filename(self) -> None:
+ """Test that CharacterPredictor correctly predicts on a single image, for several test images."""
+ predictor = CharacterPredictor()
+
+ for filename in SUPPORT_DIRNAME.glob("*.png"):
+ pred, conf = predictor.predict(str(filename))
+ print(
+ f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}"
+ )
+ self.assertEqual(pred, filename.stem)
+ self.assertGreater(conf, 0.7)
diff --git a/src/text_recognizer/util.py b/src/text_recognizer/util.py
new file mode 100644
index 0000000..52fa1e4
--- /dev/null
+++ b/src/text_recognizer/util.py
@@ -0,0 +1,51 @@
+"""Utility functions for text_recognizer module."""
+import os
+from pathlib import Path
+from typing import Union
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+
+
+def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray:
+ """Read image_uri."""
+
+ def read_image_from_filename(image_filename: str, imread_flag: int) -> np.ndarray:
+ return cv2.imread(str(image_filename), imread_flag)
+
+ def read_image_from_url(image_url: str, imread_flag: int) -> np.ndarray:
+ if image_url.lower().startswith("http"):
+ url_response = urlopen(str(image_url))
+ image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
+ return cv2.imdecode(image_array, imread_flag)
+ else:
+ raise ValueError(
+ "Url does not start with http, therfore not safe to open..."
+ ) from None
+
+ imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ local_file = os.path.exists(image_uri)
+ try:
+ image = None
+ if local_file:
+ image = read_image_from_filename(image_uri, imread_flag)
+ else:
+ image = read_image_from_url(image_uri, imread_flag)
+ assert image is not None
+ except Exception as e:
+ raise ValueError(f"Could not load image at {image_uri}: {e}")
+ return image
+
+
+def rescale_image(image: np.ndarray) -> np.ndarray:
+ """Rescale image from [0, 1] to [0, 255]."""
+ if image.max() <= 1.0:
+ image = 255 * (image - image.min()) / (image.max() - image.min())
+ return image
+
+
+def write_image(image: np.ndarray, filename: Union[Path, str]) -> None:
+ """Write image to file."""
+ image = rescale_image(image)
+ cv2.imwrite(str(filename), image)