Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/character_predictor.py | 8
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 2
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 9
-rw-r--r-- | src/text_recognizer/models/__init__.py | 4
-rw-r--r-- | src/text_recognizer/models/base.py | 66
-rw-r--r-- | src/text_recognizer/models/character_model.py | 30
-rw-r--r-- | src/text_recognizer/models/metrics.py | 2
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 4
-rw-r--r-- | src/text_recognizer/networks/lenet.py | 55
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 71
-rw-r--r-- | src/text_recognizer/tests/test_character_predictor.py | 19
-rw-r--r-- | src/text_recognizer/util.py | 2
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt | bin | 0 -> 14483400 bytes
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt | bin | 0 -> 1702233 bytes
14 files changed, 155 insertions, 117 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py
index 69ef896..a773f36 100644
--- a/src/text_recognizer/character_predictor.py
+++ b/src/text_recognizer/character_predictor.py
@@ -1,8 +1,8 @@
 """CharacterPredictor class."""
-
-from typing import Tuple, Union
+from typing import Dict, Tuple, Type, Union
 
 import numpy as np
+from torch import nn
 
 from text_recognizer.models import CharacterModel
 from text_recognizer.util import read_image
@@ -11,9 +11,9 @@ from text_recognizer.util import read_image
 class CharacterPredictor:
     """Recognizes the character in handwritten character images."""
 
-    def __init__(self) -> None:
+    def __init__(self, network_fn: Type[nn.Module], network_args: Dict) -> None:
         """Intializes the CharacterModel and load the pretrained weights."""
-        self.model = CharacterModel()
+        self.model = CharacterModel(network_fn=network_fn, network_args=network_args)
         self.model.load_weights()
         self.model.eval()
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index aec5bf9..795be90 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,2 +1,4 @@
 """Dataset modules."""
 from .emnist_dataset import EmnistDataLoader
+
+__all__ = ["EmnistDataLoader"]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index a17d7a9..b92b57d 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -2,7 +2,7 @@
 import json
 from pathlib import Path
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, List, Optional, Type
 
 from loguru import logger
 import numpy as np
@@ -102,21 +102,22 @@ class EmnistDataLoader:
         self.shuffle = shuffle
         self.num_workers = num_workers
         self.cuda = cuda
+        self.seed = seed
         self._data_loaders = self._fetch_emnist_data_loaders()
 
     @property
     def __name__(self) -> str:
         """Returns the name of the dataset."""
-        return "EMNIST"
+        return "Emnist"
 
-    def __call__(self, split: str) -> Optional[DataLoader]:
+    def __call__(self, split: str) -> DataLoader:
         """Returns the `split` DataLoader.
 
         Args:
            split (str): The dataset split, i.e. train or val.
 
        Returns:
-            type: A PyTorch DataLoader.
+            DataLoader: A PyTorch DataLoader.
 
        Raises:
            ValueError: If the split does not exist.
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index d265dcf..ff10a07 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,2 +1,6 @@
 """Model modules."""
+from .base import Model
 from .character_model import CharacterModel
+from .metrics import accuracy
+
+__all__ = ["Model", "CharacterModel", "accuracy"]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 0cc531a..b78eacb 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -1,9 +1,11 @@
 """Abstract Model class for PyTorch neural networks."""
 
 from abc import ABC, abstractmethod
+from glob import glob
 from pathlib import Path
+import re
 import shutil
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
 
 from loguru import logger
 import torch
@@ -19,7 +21,7 @@ class Model(ABC):
 
     def __init__(
         self,
-        network_fn: Callable,
+        network_fn: Type[nn.Module],
         network_args: Dict,
         data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
@@ -35,7 +37,7 @@ class Model(ABC):
         """Base class, to be inherited by model for specific type of data.
 
         Args:
-            network_fn (Callable): The PyTorch network.
+            network_fn (Type[nn.Module]): The PyTorch network.
             network_args (Dict): Arguments for the network.
             data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
             data_loader_args (Optional[Dict]): Arguments for the DataLoader.
@@ -57,27 +59,29 @@ class Model(ABC):
             self._data_loaders = data_loader(**data_loader_args)
             dataset_name = self._data_loaders.__name__
         else:
-            dataset_name = ""
+            dataset_name = "*"
             self._data_loaders = None
 
-        self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+        self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
 
         # Extract the input shape for the torchsummary.
-        self._input_shape = network_args.pop("input_shape")
+        if isinstance(network_args["input_size"], int):
+            self._input_shape = (1,) + tuple([network_args["input_size"]])
+        else:
+            self._input_shape = (1,) + tuple(network_args["input_size"])
 
         if metrics is not None:
             self._metrics = metrics
 
         # Set the device.
-        if self.device is None:
-            self._device = torch.device(
-                "cuda:0" if torch.cuda.is_available() else "cpu"
-            )
+        if device is None:
+            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         else:
             self._device = device
 
         # Load network.
-        self._network = network_fn(**network_args)
+        self.network_args = network_args
+        self._network = network_fn(**self.network_args)
 
         # To device.
         self._network.to(self._device)
@@ -95,13 +99,29 @@ class Model(ABC):
         # Set learning rate scheduler.
         self._lr_scheduler = None
         if lr_scheduler is not None:
+            # OneCycleLR needs the number of steps in an epoch as an input argument.
+            if "OneCycleLR" in str(lr_scheduler):
+                lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
             self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
 
+        # Class mapping.
+        self._mapping = None
+
+    @property
+    def __name__(self) -> str:
+        """Returns the name of the model."""
+        return self._name
+
     @property
     def input_shape(self) -> Tuple[int, ...]:
         """The input shape."""
         return self._input_shape
 
+    @property
+    def mapping(self) -> Dict:
+        """Returns the class mapping."""
+        return self._mapping
+
     def eval(self) -> None:
         """Sets the network to evaluation mode."""
         self._network.eval()
@@ -149,13 +169,14 @@
     def weights_filename(self) -> str:
         """Filepath to the network weights."""
         WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
-        return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt")
+        return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
 
     def summary(self) -> None:
         """Prints a summary of the network architecture."""
-        summary(self._network, self._input_shape, device=self.device)
+        device = re.sub("[^A-Za-z]+", "", self.device)
+        summary(self._network, self._input_shape, device=device)
 
-    def _get_state(self) -> Dict:
+    def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
         state = {"model_state": self._network.state_dict()}
         if self._optimizer is not None:
@@ -172,6 +193,7 @@
             epoch (int): The last epoch when the checkpoint was created.
 
         """
+        logger.debug("Loading checkpoint...")
         if not path.exists():
             logger.debug("File does not exist {str(path)}")
@@ -200,6 +222,7 @@
         state = self._get_state_dict()
         state["is_best"] = is_best
         state["epoch"] = epoch
+        state["network_args"] = self.network_args
 
         path.mkdir(parents=True, exist_ok=True)
@@ -216,15 +239,18 @@
     def load_weights(self) -> None:
         """Load the network weights."""
         logger.debug("Loading network weights.")
-        weights = torch.load(self.weights_filename)["model_state"]
+        filename = glob(self.weights_filename)[0]
+        weights = torch.load(filename, map_location=torch.device(self._device))[
+            "model_state"
+        ]
         self._network.load_state_dict(weights)
 
-    def save_weights(self) -> None:
+    def save_weights(self, path: Path) -> None:
         """Save the network weights."""
-        logger.debug("Saving network weights.")
-        torch.save({"model_state": self._network.state_dict()}, self.weights_filename)
+        logger.debug("Saving the best network weights.")
+        shutil.copyfile(str(path / "best.pt"), self.weights_filename)
 
     @abstractmethod
-    def mapping(self) -> Dict:
-        """Mapping from network output to class."""
+    def load_mapping(self) -> None:
+        """Loads class mapping from network output to character."""
         ...
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index fd69bf2..527fc7d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,5 +1,5 @@
 """Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
 
 import numpy as np
 import torch
@@ -8,7 +8,6 @@ from torchvision.transforms import ToTensor
 
 from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
 from text_recognizer.models.base import Model
-from text_recognizer.networks.mlp import mlp
 
 
 class CharacterModel(Model):
@@ -16,8 +15,9 @@ class CharacterModel(Model):
     def __init__(
         self,
-        network_fn: Callable,
+        network_fn: Type[nn.Module],
         network_args: Dict,
+        data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
@@ -33,6 +33,7 @@
         super().__init__(
             network_fn,
             network_args,
+            data_loader,
             data_loader_args,
             metrics,
             criterion,
@@ -43,13 +44,13 @@
             lr_scheduler_args,
             device,
         )
-        self.emnist_mapping = self.mapping()
-        self.eval()
+        self.load_mapping()
+        self.tensor_transform = ToTensor()
+        self.softmax = nn.Softmax(dim=0)
 
-    def mapping(self) -> Dict[int, str]:
+    def load_mapping(self) -> None:
         """Mapping between integers and classes."""
-        mapping = load_emnist_mapping()
-        return mapping
+        self._mapping = load_emnist_mapping()
 
     def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
         """Character prediction on an image.
@@ -61,15 +62,20 @@
             Tuple[str, float]: The predicted character and the confidence in the prediction.
 
         """
+
         if image.dtype == np.uint8:
             image = (image / 255).astype(np.float32)
 
         # Conver to Pytorch Tensor.
-        image = ToTensor(image)
+        image = self.tensor_transform(image)
+
+        with torch.no_grad():
+            logits = self.network(image)
+
+        prediction = self.softmax(logits.data.squeeze())
 
-        prediction = self.network(image)
-        index = torch.argmax(prediction, dim=1)
+        index = int(torch.argmax(prediction, dim=0))
         confidence_of_prediction = prediction[index]
-        predicted_character = self.emnist_mapping[index]
+        predicted_character = self._mapping[index]
 
         return predicted_character, confidence_of_prediction
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index e2a30a9..ac8d68e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -3,7 +3,7 @@
 import torch
 
 
-def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float:
+def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
     """Computes the accuracy.
 
     Args:
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 4ea5bb3..e6b6946 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1 +1,5 @@
 """Network modules."""
+from .lenet import LeNet
+from .mlp import MLP
+
+__all__ = ["MLP", "LeNet"]
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 71d247f..2839a0c 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,5 +1,5 @@
 """Defines the LeNet network."""
-from typing import Callable, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 
 import torch
 from torch import nn
@@ -18,28 +18,37 @@ class LeNet(nn.Module):
     def __init__(
         self,
-        channels: Tuple[int, ...],
-        kernel_sizes: Tuple[int, ...],
-        hidden_size: Tuple[int, ...],
-        dropout_rate: float,
-        output_size: int,
+        input_size: Tuple[int, ...] = (1, 28, 28),
+        channels: Tuple[int, ...] = (1, 32, 64),
+        kernel_sizes: Tuple[int, ...] = (3, 3, 2),
+        hidden_size: Tuple[int, ...] = (9216, 128),
+        dropout_rate: float = 0.2,
+        output_size: int = 10,
         activation_fn: Optional[Callable] = None,
+        activation_fn_args: Optional[Dict] = None,
     ) -> None:
         """The LeNet network.
 
         Args:
-            channels (Tuple[int, ...]): Channels in the convolutional layers.
-            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers.
+            input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28).
+            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
             hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
-            dropout_rate (float): The dropout rate.
-            output_size (int): Number of classes.
+                Defaults to (9216, 128).
+            dropout_rate (float): The dropout rate. Defaults to 0.2.
+            output_size (int): Number of classes. Defaults to 10.
             activation_fn (Optional[Callable]): The non-linear activation function. Defaults to
                 nn.ReLU(inplace).
+            activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
 
         """
         super().__init__()
 
-        if activation_fn is None:
+        self._input_size = input_size
+
+        if activation_fn is not None:
+            activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+        else:
             activation_fn = nn.ReLU(inplace=True)
 
         self.layers = [
@@ -68,26 +77,6 @@
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """The feedforward."""
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
         return self.layers(x)
-
-
-# def test():
-#     x = torch.randn([1, 1, 28, 28])
-#     channels = [1, 32, 64]
-#     kernel_sizes = [3, 3, 2]
-#     hidden_size = [9216, 128]
-#     output_size = 10
-#     dropout_rate = 0.2
-#     activation_fn = nn.ReLU()
-#     net = LeNet(
-#         channels=channels,
-#         kernel_sizes=kernel_sizes,
-#         dropout_rate=dropout_rate,
-#         hidden_size=hidden_size,
-#         output_size=output_size,
-#         activation_fn=activation_fn,
-#     )
-#     from torchsummary import summary
-#
-#     summary(net, (1, 28, 28), device="cpu")
-#     out = net(x)
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index 2a41790..d704d99 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -1,5 +1,5 @@
 """Defines the MLP network."""
-from typing import Callable, Optional
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 from torch import nn
@@ -10,45 +10,54 @@ class MLP(nn.Module):
     def __init__(
         self,
-        input_size: int,
-        output_size: int,
-        hidden_size: int,
-        num_layers: int,
-        dropout_rate: float,
+        input_size: int = 784,
+        output_size: int = 10,
+        hidden_size: Union[int, List] = 128,
+        num_layers: int = 3,
+        dropout_rate: float = 0.2,
         activation_fn: Optional[Callable] = None,
+        activation_fn_args: Optional[Dict] = None,
     ) -> None:
         """Initialization of the MLP network.
 
         Args:
-            input_size (int): The input shape of the network.
-            output_size (int): Number of classes in the dataset.
-            hidden_size (int): The number of `neurons` in each hidden layer.
-            num_layers (int): The number of hidden layers.
-            dropout_rate (float): The dropout rate at each layer.
-            activation_fn (Optional[Callable]): The activation function in the hidden layers, (default:
-                nn.ReLU()).
+            input_size (int): The input shape of the network. Defaults to 784.
+            output_size (int): Number of classes in the dataset. Defaults to 10.
+            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+            num_layers (int): The number of hidden layers. Defaults to 3.
+            dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
+            activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to
+                None.
+            activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
 
         """
         super().__init__()
 
-        if activation_fn is None:
+        if activation_fn is not None:
+            activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+        else:
             activation_fn = nn.ReLU(inplace=True)
 
+        if isinstance(hidden_size, int):
+            hidden_size = [hidden_size] * num_layers
+
         self.layers = [
-            nn.Linear(in_features=input_size, out_features=hidden_size),
+            nn.Linear(in_features=input_size, out_features=hidden_size[0]),
             activation_fn,
         ]
 
-        for _ in range(num_layers):
+        for i in range(num_layers - 1):
             self.layers += [
-                nn.Linear(in_features=hidden_size, out_features=hidden_size),
+                nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
                 activation_fn,
             ]
 
             if dropout_rate:
                 self.layers.append(nn.Dropout(p=dropout_rate))
 
-        self.layers.append(nn.Linear(in_features=hidden_size, out_features=output_size))
+        self.layers.append(
+            nn.Linear(in_features=hidden_size[-1], out_features=output_size)
+        )
 
         self.layers = nn.Sequential(*self.layers)
@@ -57,25 +66,7 @@
         x = torch.flatten(x, start_dim=1)
         return self.layers(x)
 
-
-# def test():
-#     x = torch.randn([1, 28, 28])
-#     input_size = torch.flatten(x).shape[0]
-#     output_size = 10
-#     hidden_size = 128
-#     num_layers = 5
-#     dropout_rate = 0.25
-#     activation_fn = nn.GELU()
-#     net = MLP(
-#         input_size=input_size,
-#         output_size=output_size,
-#         hidden_size=hidden_size,
-#         num_layers=num_layers,
-#         dropout_rate=dropout_rate,
-#         activation_fn=activation_fn,
-#     )
-#     from torchsummary import summary
-#
-#     summary(net, (1, 28, 28), device="cpu")
-#
-#     out = net(x)
+    @property
+    def __name__(self) -> str:
+        """Returns the name of the network."""
+        return "mlp"
diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py
index 7c094ef..c603a3a 100644
--- a/src/text_recognizer/tests/test_character_predictor.py
+++ b/src/text_recognizer/tests/test_character_predictor.py
@@ -1,9 +1,14 @@
 """Test for CharacterPredictor class."""
+import importlib
 import os
 from pathlib import Path
 import unittest
 
+import click
+from loguru import logger
+
 from text_recognizer.character_predictor import CharacterPredictor
+from text_recognizer.networks import MLP
 
 SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist"
 
@@ -13,13 +18,23 @@ os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 class TestCharacterPredictor(unittest.TestCase):
     """Tests for the CharacterPredictor class."""
 
+    # @click.command()
+    # @click.option(
+    #     "--network", type=str, help="Network to load, e.g. MLP or LeNet.", default="MLP"
+    # )
     def test_filename(self) -> None:
         """Test that CharacterPredictor correctly predicts on a single image, for serveral test images."""
-        predictor = CharacterPredictor()
+        network_module = importlib.import_module("text_recognizer.networks")
+        network_fn_ = getattr(network_module, "MLP")
+        # network_args = {"input_size": [28, 28], "output_size": 62, "dropout_rate": 0}
+        network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2}
+        predictor = CharacterPredictor(
+            network_fn=network_fn_, network_args=network_args
+        )
         for filename in SUPPORT_DIRNAME.glob("*.png"):
             pred, conf = predictor.predict(str(filename))
-            print(
+            logger.info(
                 f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}"
             )
             self.assertEqual(pred, filename.stem)
diff --git a/src/text_recognizer/util.py b/src/text_recognizer/util.py
index 52fa1e4..6c07c60 100644
--- a/src/text_recognizer/util.py
+++ b/src/text_recognizer/util.py
@@ -25,7 +25,7 @@ def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarr
         ) from None
 
     imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
-    local_file = os.path.exsits(image_uri)
+    local_file = os.path.exists(image_uri)
     try:
         image = None
         if local_file:
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
new file mode 100644
index 0000000..43a3891
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt
new file mode 100644
index 0000000..0dde787
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt
Binary files differ