summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/character_model.py
blob: 50e94a205a5fb12b0e70ead9d2e9b9c6a743c920 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Defines the CharacterModel class."""
from typing import Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

from text_recognizer.datasets import EmnistMapper
from text_recognizer.models.base import Model


class CharacterModel(Model):
    """Model for predicting a single character from an image.

    Extends the base ``Model`` with an ``EmnistMapper`` fallback for index to
    character decoding and a ``predict_on_image`` convenience method.
    """

    def __init__(
        self,
        network_fn: Type[nn.Module],
        dataset: Type[Dataset],
        network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        swa_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initializes the CharacterModel.

        All arguments are forwarded positionally, unchanged, to
        ``Model.__init__``; see the base class for their exact semantics.

        Args:
            network_fn (Type[nn.Module]): The network class.
            dataset (Type[Dataset]): The dataset class.
            network_args (Optional[Dict]): Arguments for the network.
            dataset_args (Optional[Dict]): Arguments for the dataset.
            metrics (Optional[Dict]): Metrics passed on to the base model.
            criterion (Optional[Callable]): The criterion (loss) callable.
            criterion_args (Optional[Dict]): Arguments for the criterion.
            optimizer (Optional[Callable]): The optimizer callable.
            optimizer_args (Optional[Dict]): Arguments for the optimizer.
            lr_scheduler (Optional[Callable]): The learning rate scheduler callable.
            lr_scheduler_args (Optional[Dict]): Arguments for the scheduler.
            swa_args (Optional[Dict]): Passed on to the base model
                (presumably stochastic weight averaging settings; confirm
                against ``Model``).
            device (Optional[str]): Device name passed on to the base model.

        """

        super().__init__(
            network_fn,
            dataset,
            network_args,
            dataset_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            swa_args,
            device,
        )
        # Fall back to the EMNIST mapping when the base class did not set one.
        if self._mapper is None:
            self._mapper = EmnistMapper()
        self.tensor_transform = ToTensor()
        # dim=0 because predict_on_image squeezes the leading dimension of the
        # logits before applying the softmax.
        self.softmax = nn.Softmax(dim=0)

    @torch.no_grad()
    def predict_on_image(
        self, image: Union[np.ndarray, torch.Tensor]
    ) -> Tuple[str, float]:
        """Character prediction on an image.

        Args:
            image (Union[np.ndarray, torch.Tensor]): An image containing a character.
                uint8 input (array or tensor) is rescaled from [0, 255] to [0, 1];
                any other dtype is used without rescaling.

        Returns:
            Tuple[str, float]: The predicted character and the confidence in the prediction.

        """

        if isinstance(image, np.ndarray):
            # ToTensor scales uint8 arrays from [0, 255] to a float tensor in
            # [0, 1]; arrays of any other dtype are converted to a tensor
            # without scaling, so they no longer fail on `.to()` below.
            image = self.tensor_transform(image)
        elif image.dtype == torch.uint8:
            # If the image is an unscaled tensor.
            image = image.type("torch.FloatTensor") / 255

        # Put the image tensor on the device the model weights are on.
        image = image.to(self.device)
        logits = self.forward(image)

        prediction = self.softmax(logits.squeeze(0))

        index = int(torch.argmax(prediction, dim=0))
        # Convert the 0-dim tensor to a plain Python float so the return value
        # matches the annotated type and does not keep a device tensor alive.
        confidence_of_prediction = float(prediction[index])
        predicted_character = self.mapper(index)

        return predicted_character, confidence_of_prediction