diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-03 23:33:34 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-03 23:33:34 +0200 |
commit | 07dd14116fe1d8148fb614b160245287533620fc (patch) | |
tree | 63395d88b17a14ad453c52889fcf541e6cbbdd3e /src/text_recognizer/models/character_model.py | |
parent | 704451318eb6b0b600ab314cb5aabfac82416bda (diff) |
Working Emnist lines dataset.
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r-- | src/text_recognizer/models/character_model.py | 32 |
1 files changed, 21 insertions, 11 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 527fc7d..f1dabb7 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -1,12 +1,15 @@ """Defines the CharacterModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type +from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch from torch import nn from torchvision.transforms import ToTensor -from text_recognizer.datasets.emnist_dataset import load_emnist_mapping +from text_recognizer.datasets.emnist_dataset import ( + _augment_emnist_mapping, + _load_emnist_essentials, +) from text_recognizer.models.base import Model @@ -16,7 +19,7 @@ class CharacterModel(Model): def __init__( self, network_fn: Type[nn.Module], - network_args: Dict, + network_args: Optional[Dict] = None, data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, @@ -44,19 +47,23 @@ class CharacterModel(Model): lr_scheduler_args, device, ) - self.load_mapping() + if self.mapping is None: + self.load_mapping() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) def load_mapping(self) -> None: """Mapping between integers and classes.""" - self._mapping = load_emnist_mapping() + essentials = _load_emnist_essentials() + self._mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]: + def predict_on_image( + self, image: Union[np.ndarray, torch.Tensor] + ) -> Tuple[str, float]: """Character prediction on an image. Args: - image (np.ndarray): An image containing a character. + image (Union[np.ndarray, torch.Tensor]): An image containing a character. Returns: Tuple[str, float]: The predicted character and the confidence in the prediction. @@ -64,12 +71,15 @@ class CharacterModel(Model): """ if image.dtype == np.uint8: - image = (image / 255).astype(np.float32) - - # Conver to Pytorch Tensor. - image = self.tensor_transform(image) + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 with torch.no_grad(): + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) logits = self.network(image) prediction = self.softmax(logits.data.squeeze()) |