From f473456c19558aaf8552df97a51d4e18cc69dfa8 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 22 Jul 2020 23:18:08 +0200 Subject: Working training loop and testing of trained CharacterModel. --- src/text_recognizer/models/character_model.py | 30 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) (limited to 'src/text_recognizer/models/character_model.py') diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index fd69bf2..527fc7d 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -1,5 +1,5 @@ """Defines the CharacterModel class.""" -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Type import numpy as np import torch @@ -8,7 +8,6 @@ from torchvision.transforms import ToTensor from text_recognizer.datasets.emnist_dataset import load_emnist_mapping from text_recognizer.models.base import Model -from text_recognizer.networks.mlp import mlp class CharacterModel(Model): @@ -16,8 +15,9 @@ class CharacterModel(Model): def __init__( self, - network_fn: Callable, + network_fn: Type[nn.Module], network_args: Dict, + data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -33,6 +33,7 @@ class CharacterModel(Model): super().__init__( network_fn, network_args, + data_loader, data_loader_args, metrics, criterion, @@ -43,13 +44,13 @@ class CharacterModel(Model): lr_scheduler_args, device, ) - self.emnist_mapping = self.mapping() - self.eval() + self.load_mapping() + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=0) - def mapping(self) -> Dict[int, str]: + def load_mapping(self) -> None: """Mapping between integers and classes.""" - mapping = load_emnist_mapping() - return mapping + self._mapping = load_emnist_mapping() def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]: """Character prediction on an image. @@ -61,15 +62,20 @@ class CharacterModel(Model): Tuple[str, float]: The predicted character and the confidence in the prediction. """ + if image.dtype == np.uint8: image = (image / 255).astype(np.float32) # Conver to Pytorch Tensor. - image = ToTensor(image) + image = self.tensor_transform(image) + + with torch.no_grad(): + logits = self.network(image) + + prediction = self.softmax(logits.data.squeeze()) - prediction = self.network(image) - index = torch.argmax(prediction, dim=1) + index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] - predicted_character = self.emnist_mapping[index] + predicted_character = self._mapping[index] return predicted_character, confidence_of_prediction -- cgit v1.2.3-70-g09d2