summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/character_model.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-07-22 23:18:08 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-07-22 23:18:08 +0200
commitf473456c19558aaf8552df97a51d4e18cc69dfa8 (patch)
tree0d35ce2410ff623ba5fb433d616d95b67ecf7a98 /src/text_recognizer/models/character_model.py
parentad3bd52530f4800d4fb05dfef3354921f95513af (diff)
Working training loop and testing of trained CharacterModel.
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r--src/text_recognizer/models/character_model.py30
1 files changed, 18 insertions, 12 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index fd69bf2..527fc7d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,5 +1,5 @@
"""Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
import numpy as np
import torch
@@ -8,7 +8,6 @@ from torchvision.transforms import ToTensor
from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
from text_recognizer.models.base import Model
-from text_recognizer.networks.mlp import mlp
class CharacterModel(Model):
@@ -16,8 +15,9 @@ class CharacterModel(Model):
def __init__(
self,
- network_fn: Callable,
+ network_fn: Type[nn.Module],
network_args: Dict,
+ data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -33,6 +33,7 @@ class CharacterModel(Model):
super().__init__(
network_fn,
network_args,
+ data_loader,
data_loader_args,
metrics,
criterion,
@@ -43,13 +44,13 @@ class CharacterModel(Model):
lr_scheduler_args,
device,
)
- self.emnist_mapping = self.mapping()
- self.eval()
+ self.load_mapping()
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
- def mapping(self) -> Dict[int, str]:
+ def load_mapping(self) -> None:
"""Mapping between integers and classes."""
- mapping = load_emnist_mapping()
- return mapping
+ self._mapping = load_emnist_mapping()
def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
"""Character prediction on an image.
@@ -61,15 +62,20 @@ class CharacterModel(Model):
Tuple[str, float]: The predicted character and the confidence in the prediction.
"""
+
if image.dtype == np.uint8:
image = (image / 255).astype(np.float32)
# Conver to Pytorch Tensor.
- image = ToTensor(image)
+ image = self.tensor_transform(image)
+
+ with torch.no_grad():
+ logits = self.network(image)
+
+ prediction = self.softmax(logits.data.squeeze())
- prediction = self.network(image)
- index = torch.argmax(prediction, dim=1)
+ index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]
- predicted_character = self.emnist_mapping[index]
+ predicted_character = self._mapping[index]
return predicted_character, confidence_of_prediction