summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/character_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r-- src/text_recognizer/models/character_model.py 32
1 files changed, 21 insertions, 11 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 527fc7d..f1dabb7 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,12 +1,15 @@
"""Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
import torch
from torch import nn
from torchvision.transforms import ToTensor
-from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
+from text_recognizer.datasets.emnist_dataset import (
+ _augment_emnist_mapping,
+ _load_emnist_essentials,
+)
from text_recognizer.models.base import Model
@@ -16,7 +19,7 @@ class CharacterModel(Model):
def __init__(
self,
network_fn: Type[nn.Module],
- network_args: Dict,
+ network_args: Optional[Dict] = None,
data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
@@ -44,19 +47,23 @@ class CharacterModel(Model):
lr_scheduler_args,
device,
)
- self.load_mapping()
+ if self.mapping is None:
+ self.load_mapping()
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
def load_mapping(self) -> None:
"""Mapping between integers and classes."""
- self._mapping = load_emnist_mapping()
+ essentials = _load_emnist_essentials()
+ self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
- def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
+ def predict_on_image(
+ self, image: Union[np.ndarray, torch.Tensor]
+ ) -> Tuple[str, float]:
"""Character prediction on an image.
Args:
- image (np.ndarray): An image containing a character.
+ image (Union[np.ndarray, torch.Tensor]): An image containing a character.
Returns:
Tuple[str, float]: The predicted character and the confidence in the prediction.
@@ -64,12 +71,15 @@ class CharacterModel(Model):
"""
if image.dtype == np.uint8:
- image = (image / 255).astype(np.float32)
-
- # Conver to Pytorch Tensor.
- image = self.tensor_transform(image)
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
with torch.no_grad():
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
logits = self.network(image)
prediction = self.softmax(logits.data.squeeze())