1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
"""Defines the CharacterModel class."""
from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
import torch
from torch import nn
from torchvision.transforms import ToTensor
from text_recognizer.models.base import Model
class CharacterModel(Model):
    """Model for predicting characters from images."""

    def __init__(
        self,
        network_fn: Type[nn.Module],
        network_args: Optional[Dict] = None,
        data_loader_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initializes the CharacterModel.

        Every argument is forwarded positionally, unchanged, to the base
        ``Model`` initializer; the base class defines their exact semantics.

        Args:
            network_fn (Type[nn.Module]): The network class.
            network_args (Optional[Dict]): Arguments for the network.
            data_loader_args (Optional[Dict]): Arguments for the data loader.
            metrics (Optional[Dict]): Metrics, presumably keyed by name.
            criterion (Optional[Callable]): The criterion (loss) callable.
            criterion_args (Optional[Dict]): Arguments for the criterion.
            optimizer (Optional[Callable]): The optimizer callable.
            optimizer_args (Optional[Dict]): Arguments for the optimizer.
            lr_scheduler (Optional[Callable]): The learning rate scheduler callable.
            lr_scheduler_args (Optional[Dict]): Arguments for the scheduler.
            device (Optional[str]): Name of the device to run on.

        """
        super().__init__(
            network_fn,
            network_args,
            data_loader_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            device,
        )
        # Converts numpy images to tensors inside predict_on_image.
        self.tensor_transform = ToTensor()
        # Applied to the squeezed logits, hence the class axis is dim 0.
        self.softmax = nn.Softmax(dim=0)

    def predict_on_image(
        self, image: Union[np.ndarray, torch.Tensor]
    ) -> Tuple[str, float]:
        """Character prediction on an image.

        Args:
            image (Union[np.ndarray, torch.Tensor]): An image containing a character.
                uint8 input (array or tensor) is rescaled from [0, 255] to [0, 1];
                any other dtype is passed to the network without rescaling.

        Returns:
            Tuple[str, float]: The predicted character and the confidence in the prediction.

        """
        if isinstance(image, np.ndarray):
            # ToTensor rescales uint8 arrays from [0, 255] to [0, 1] and converts
            # arrays of any other dtype to a tensor without scaling. Previously
            # only uint8 arrays were converted, so e.g. a float array crashed
            # further down on `image.to(...)`.
            image = self.tensor_transform(image)

        if image.dtype == torch.uint8:
            # If the image is an unscaled tensor.
            image = image.float() / 255

        with torch.no_grad():
            # Put the image tensor on the device the model weights are on.
            image = image.to(self.device)
            logits = self.network(image)
            # No `.data` needed: nothing computed under no_grad tracks gradients.
            prediction = self.softmax(logits.squeeze())

        index = int(torch.argmax(prediction, dim=0))
        # Return a plain Python float, as the signature promises, rather than
        # a 0-dim tensor that may still live on the accelerator.
        confidence_of_prediction = float(prediction[index])
        predicted_character = self.mapper(index)

        return predicted_character, confidence_of_prediction
|