From 7c4de6d88664d2ea1b084f316a11896dde3e1150 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Tue, 23 Jun 2020 22:39:54 +0200 Subject: latest --- src/text_recognizer/models/__init__.py | 1 + src/text_recognizer/models/base.py | 230 ++++++++++++++++++++++++++ src/text_recognizer/models/character_model.py | 71 ++++++++ src/text_recognizer/models/util.py | 19 +++ 4 files changed, 321 insertions(+) create mode 100644 src/text_recognizer/models/base.py create mode 100644 src/text_recognizer/models/character_model.py create mode 100644 src/text_recognizer/models/util.py (limited to 'src/text_recognizer/models') diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index aa26de6..d265dcf 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1 +1,2 @@ """Model modules.""" +from .character_model import CharacterModel diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py new file mode 100644 index 0000000..736af7b --- /dev/null +++ b/src/text_recognizer/models/base.py @@ -0,0 +1,230 @@ +"""Abstract Model class for PyTorch neural networks.""" + +from abc import ABC, abstractmethod +from pathlib import Path +import shutil +from typing import Callable, Dict, Optional, Tuple + +from loguru import logger +import torch +from torch import nn +from torchsummary import summary + +from text_recognizer.dataset.data_loader import fetch_data_loader + +WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" + + +class Model(ABC): + """Abstract Model class with composition of different parts defining a PyTorch neural network.""" + + def __init__( + self, + network_fn: Callable, + network_args: Dict, + data_loader_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Base class, to be inherited by predictors for specific type of data. + + Args: + network_fn (Callable): The PyTorch network. + network_args (Dict): Arguments for the network. + data_loader_args (Optional[Dict]): Arguments for the data loader. + metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. + criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. + Defaults to None. + criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. + optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. + optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None. + lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. + lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to + None. + device (Optional[str]): Name of the device to train on. Defaults to None. + + """ + + # Fetch data loaders. + if data_loader_args is not None: + self._data_loaders = fetch_data_loader(**data_loader_args) + dataset_name = self._data_loaders.items()[0].dataset.__name__ + else: + dataset_name = "" + self._data_loaders = None + + self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" + + # Extract the input shape for the torchsummary. + self._input_shape = network_args.pop("input_shape") + + if metrics is not None: + self._metrics = metrics + + # Set the device. + if self.device is None: + self._device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" + ) + else: + self._device = device + + # Load network. + self._network = network_fn(**network_args) + + # To device. + self._network.to(self._device) + + # Set criterion. + self._criterion = None + if criterion is not None: + self._criterion = criterion(**criterion_args) + + # Set optimizer. + self._optimizer = None + if optimizer is not None: + self._optimizer = optimizer(self._network.parameters(), **optimizer_args) + + # Set learning rate scheduler. + self._lr_scheduler = None + if lr_scheduler is not None: + self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) + + @property + def input_shape(self) -> Tuple[int, ...]: + """The input shape.""" + return self._input_shape + + def eval(self) -> None: + """Sets the network to evaluation mode.""" + self._network.eval() + + def train(self) -> None: + """Sets the network to train mode.""" + self._network.train() + + @property + def device(self) -> str: + """Device where the weights are stored, i.e. cpu or cuda.""" + return self._device + + @property + def metrics(self) -> Optional[Dict]: + """Metrics.""" + return self._metrics + + @property + def criterion(self) -> Optional[Callable]: + """Criterion.""" + return self._criterion + + @property + def optimizer(self) -> Optional[Callable]: + """Optimizer.""" + return self._optimizer + + @property + def lr_scheduler(self) -> Optional[Callable]: + """Learning rate scheduler.""" + return self._lr_scheduler + + @property + def data_loaders(self) -> Optional[Dict]: + """Dataloaders.""" + return self._data_loaders + + @property + def network(self) -> nn.Module: + """Neural network.""" + return self._network + + @property + def weights_filename(self) -> str: + """Filepath to the network weights.""" + WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) + return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt") + + def summary(self) -> None: + """Prints a summary of the network architecture.""" + summary(self._network, self._input_shape, device=self.device) + + def _get_state(self) -> Dict: + """Get the state dict of the model.""" + state = {"model_state": self._network.state_dict()} + if self._optimizer is not None: + state["optimizer_state"] = self._optimizer.state_dict() + return state + + def load_checkpoint(self, path: Path) -> int: + """Load a previously saved checkpoint. + + Args: + path (Path): Path to the experiment with the checkpoint. + + Returns: + epoch (int): The last epoch when the checkpoint was created. + + """ + if not path.exists(): + logger.debug("File does not exist {str(path)}") + + checkpoint = torch.load(str(path)) + self._network.load_state_dict(checkpoint["model_state"]) + + if self._optimizer is not None: + self._optimizer.load_state_dict(checkpoint["optimizer_state"]) + + epoch = checkpoint["epoch"] + + return epoch + + def save_checkpoint( + self, path: Path, is_best: bool, epoch: int, val_metric: str + ) -> None: + """Saves a checkpoint of the model. + + Args: + path (Path): Path to the experiment folder. + is_best (bool): If it is the currently best model. + epoch (int): The epoch of the checkpoint. + val_metric (str): Validation metric. + + """ + state = self._get_state_dict() + state["is_best"] = is_best + state["epoch"] = epoch + + path.mkdir(parents=True, exist_ok=True) + + logger.debug("Saving checkpoint...") + filepath = str(path / "last.pt") + torch.save(state, filepath) + + if is_best: + logger.debug( + f"Found a new best {val_metric}. Saving best checkpoint and weights." + ) + self.save_weights() + shutil.copyfile(filepath, str(path / "best.pt")) + + def load_weights(self) -> None: + """Load the network weights.""" + logger.debug("Loading network weights.") + weights = torch.load(self.weights_filename)["model_state"] + self._network.load_state_dict(weights) + + def save_weights(self) -> None: + """Save the network weights.""" + logger.debug("Saving network weights.") + torch.save({"model_state": self._network.state_dict()}, self.weights_filename) + + @abstractmethod + def mapping(self) -> Dict: + """Mapping from network output to class.""" + ... diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py new file mode 100644 index 0000000..1570344 --- /dev/null +++ b/src/text_recognizer/models/character_model.py @@ -0,0 +1,71 @@ +"""Defines the CharacterModel class.""" +from typing import Callable, Dict, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.emnist_dataset import load_emnist_mapping +from text_recognizer.models.base import Model +from text_recognizer.networks.mlp import mlp + + +class CharacterModel(Model): + """Model for predicting characters from images.""" + + def __init__( + self, + network_fn: Callable, + network_args: Dict, + data_loader_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + data_loader_args, + network_args, + metrics, + criterion, + optimizer, + device, + ) + self.emnist_mapping = self.mapping() + self.eval() + + def mapping(self) -> Dict: + """Mapping between integers and classes.""" + mapping = load_emnist_mapping() + return mapping + + def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]: + """Character prediction on an image. + + Args: + image (np.ndarray): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. + + """ + if image.dtype == np.uint8: + image = (image / 255).astype(np.float32) + + # Conver to Pytorch Tensor. + image = ToTensor(image) + + prediction = self.network(image) + index = torch.argmax(prediction, dim=1) + confidence_of_prediction = prediction[index] + predicted_character = self.emnist_mapping[index] + + return predicted_character, confidence_of_prediction diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/util.py new file mode 100644 index 0000000..905fe7b --- /dev/null +++ b/src/text_recognizer/models/util.py @@ -0,0 +1,19 @@ +"""Utility functions for models.""" + +import torch + + +def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: + """Short summary. + + Args: + outputs (torch.Tensor): The output from the network. + labels (torch.Tensor): Ground truth labels. + + Returns: + float: The accuracy for the batch. + + """ + _, predicted = torch.max(outputs.data, dim=1) + acc = (predicted == labels).sum().item() / labels.shape[0] + return acc -- cgit v1.2.3-70-g09d2