summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/character_model.py
blob: fd69bf25bacf62bb2ca2f2ca81db0f52b1cb3073 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Defines the CharacterModel class."""
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torchvision.transforms import ToTensor

from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
from text_recognizer.models.base import Model
from text_recognizer.networks.mlp import mlp


class CharacterModel(Model):
    """Model for predicting characters from images."""

    def __init__(
        self,
        network_fn: Callable,
        network_args: Dict,
        data_loader_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initializes the CharacterModel.

        Every argument is forwarded positionally and unchanged to the base
        ``Model`` constructor; see that class for what each one configures.
        Afterwards the EMNIST index-to-character mapping is loaded and
        ``self.eval()`` is called.
        """
        super().__init__(
            network_fn,
            network_args,
            data_loader_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            device,
        )
        # Loaded once here so that predictions do not reload the mapping.
        self.emnist_mapping = self.mapping()
        self.eval()

    def mapping(self) -> Dict[int, str]:
        """Mapping between integers and classes.

        Returns:
            Dict[int, str]: The mapping returned by ``load_emnist_mapping``.

        """
        mapping = load_emnist_mapping()
        return mapping

    def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
        """Character prediction on an image.

        Args:
            image (np.ndarray): An image containing a character. ``uint8``
                images are rescaled from [0, 255] to [0, 1]; any other dtype is
                cast to ``float32`` without rescaling.

        Returns:
            Tuple[str, float]: The predicted character and the confidence in the prediction.

        """
        if image.dtype == np.uint8:
            image = (image / 255).astype(np.float32)
        else:
            # Avoid handing the network a float64 (double) tensor.
            image = image.astype(np.float32, copy=False)

        # Convert to a PyTorch tensor. ToTensor is a transform class, so it has
        # to be instantiated first and then called on the image.
        image = ToTensor()(image)

        # NOTE(review): the tensor is not moved to the device the network lives
        # on; confirm whether the base Model handles device placement.
        with torch.no_grad():  # Inference only, no gradients needed.
            prediction = self.network(image)

        # Assumes the network returns one row of class scores for the single
        # input image (e.g. shape (1, num_classes)) -- TODO confirm. Flattening
        # lets a plain integer index address the class score directly.
        prediction = prediction.reshape(-1)

        # A Python int is required both to index the score vector and to look
        # up the character in the int-keyed mapping.
        index = int(torch.argmax(prediction))

        # NOTE(review): this is the raw network output for the winning class;
        # it is only a probability if the network ends in a softmax -- confirm.
        confidence_of_prediction = float(prediction[index])
        predicted_character = self.emnist_mapping[index]

        return predicted_character, confidence_of_prediction