summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/character_model.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-06-23 22:39:54 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-06-23 22:39:54 +0200
commit7c4de6d88664d2ea1b084f316a11896dde3e1150 (patch)
treecbde7e64c8064e9b523dfb65cd6c487d061ec805 /src/text_recognizer/models/character_model.py
parenta7a9ce59fc37317dd74c3a52caf7c4e68e570ee8 (diff)
latest
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r--src/text_recognizer/models/character_model.py71
1 files changed, 71 insertions, 0 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
new file mode 100644
index 0000000..1570344
--- /dev/null
+++ b/src/text_recognizer/models/character_model.py
@@ -0,0 +1,71 @@
+"""Defines the CharacterModel class."""
+from typing import Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
+from text_recognizer.models.base import Model
+from text_recognizer.networks.mlp import mlp
+
+
+class CharacterModel(Model):
+ """Model for predicting characters from images."""
+
+ def __init__(
+ self,
+ network_fn: Callable,
+ network_args: Dict,
+ data_loader_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ data_loader_args,
+ network_args,
+ metrics,
+ criterion,
+ optimizer,
+ device,
+ )
+ self.emnist_mapping = self.mapping()
+ self.eval()
+
+ def mapping(self) -> Dict:
+ """Mapping between integers and classes."""
+ mapping = load_emnist_mapping()
+ return mapping
+
+ def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
+ """Character prediction on an image.
+
+ Args:
+ image (np.ndarray): An image containing a character.
+
+ Returns:
+ Tuple[str, float]: The predicted character and the confidence in the prediction.
+
+ """
+ if image.dtype == np.uint8:
+ image = (image / 255).astype(np.float32)
+
+ # Conver to Pytorch Tensor.
+ image = ToTensor(image)
+
+ prediction = self.network(image)
+ index = torch.argmax(prediction, dim=1)
+ confidence_of_prediction = prediction[index]
+ predicted_character = self.emnist_mapping[index]
+
+ return predicted_character, confidence_of_prediction