From 7c4de6d88664d2ea1b084f316a11896dde3e1150 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Tue, 23 Jun 2020 22:39:54 +0200 Subject: latest --- src/text_recognizer/models/character_model.py | 71 +++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 src/text_recognizer/models/character_model.py (limited to 'src/text_recognizer/models/character_model.py') diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py new file mode 100644 index 0000000..1570344 --- /dev/null +++ b/src/text_recognizer/models/character_model.py @@ -0,0 +1,71 @@ +"""Defines the CharacterModel class.""" +from typing import Callable, Dict, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.emnist_dataset import load_emnist_mapping +from text_recognizer.models.base import Model +from text_recognizer.networks.mlp import mlp + + +class CharacterModel(Model): + """Model for predicting characters from images.""" + + def __init__( + self, + network_fn: Callable, + network_args: Dict, + data_loader_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + data_loader_args, + network_args, + metrics, + criterion, + optimizer, + device, + ) + self.emnist_mapping = self.mapping() + self.eval() + + def mapping(self) -> Dict: + """Mapping between integers and classes.""" + mapping = load_emnist_mapping() + return mapping + + def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]: + """Character prediction on an image. + + Args: + image (np.ndarray): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. + + """ + if image.dtype == np.uint8: + image = (image / 255).astype(np.float32) + + # Conver to Pytorch Tensor. + image = ToTensor(image) + + prediction = self.network(image) + index = torch.argmax(prediction, dim=1) + confidence_of_prediction = prediction[index] + predicted_character = self.emnist_mapping[index] + + return predicted_character, confidence_of_prediction -- cgit v1.2.3-70-g09d2