summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqvae_model.py
blob: 70f6f1f73538043e4a38f8730ceb3fa124813d2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Defines the VQVAEModel class."""
from typing import Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

from text_recognizer.datasets import EmnistMapper
from text_recognizer.models.base import Model


class VQVAEModel(Model):
    """Model for reconstructing images from codebook."""

    def __init__(
        self,
        network_fn: Type[nn.Module],
        dataset: Type[Dataset],
        network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        swa_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initializes the VQVAEModel.

        All arguments are forwarded unchanged to the base ``Model``.

        Args:
            network_fn (Type[nn.Module]): Network class for the model.
            dataset (Type[Dataset]): Dataset class.
            network_args (Optional[Dict]): Arguments for the network.
            dataset_args (Optional[Dict]): Arguments for the dataset. Although
                typed as optional, it is required here: it must contain
                ``dataset_args["args"]["pad_token"]``.
            metrics (Optional[Dict]): Metrics passed to the base model.
            criterion (Optional[Callable]): Loss function.
            criterion_args (Optional[Dict]): Arguments for the loss function.
            optimizer (Optional[Callable]): Optimizer.
            optimizer_args (Optional[Dict]): Arguments for the optimizer.
            lr_scheduler (Optional[Callable]): Learning-rate scheduler.
            lr_scheduler_args (Optional[Dict]): Arguments for the scheduler.
            swa_args (Optional[Dict]): Arguments passed to the base model
                (presumably stochastic weight averaging - confirm in ``Model``).
            device (Optional[str]): Device name passed to the base model.

        Raises:
            TypeError: If ``dataset_args`` is None.
            KeyError: If ``dataset_args`` lacks ``["args"]["pad_token"]``.

        """
        super().__init__(
            network_fn,
            dataset,
            network_args,
            dataset_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            swa_args,
            device,
        )
        self.pad_token = dataset_args["args"]["pad_token"]

        # Fall back to a default EMNIST mapper when none has been set
        # (``self._mapper`` is presumably initialised by the base Model).
        if self._mapper is None:
            self._mapper = EmnistMapper(pad_token=self.pad_token)

        self.tensor_transform = ToTensor()

        # Not used by the methods of this class; kept so that code relying on
        # the attribute keeps working.
        self.softmax = nn.Softmax(dim=0)

    @torch.no_grad()
    def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Reconstructs an image by passing it through the network.

        Side effect: the model is switched to evaluation mode (``self.eval()``)
        and is left in that mode after the call. Gradients are disabled for
        the duration of the call.

        Args:
            image (Union[np.ndarray, torch.Tensor]): The image to reconstruct.
                Any NumPy array is converted with ``ToTensor``:
                - a ``uint8`` array is scaled from [0, 255] to [0, 1];
                - other dtypes are converted without rescaling;
                - arrays with channels last (HWC) are transposed to CHW.
                A ``uint8`` tensor is converted to float and scaled to [0, 1].
                No batch dimension is added; presumably the network must
                accept the resulting shape - confirm against the network.

        Returns:
            torch.Tensor: The first output of ``self.forward`` for the image,
            i.e. the reconstructed image.

        """
        self.eval()

        if isinstance(image, np.ndarray):
            # Convert every ndarray (not only uint8) so that `.to()` below always
            # receives a tensor. ToTensor rescales only uint8 input.
            image = self.tensor_transform(image)
        if image.dtype == torch.uint8:
            # An unscaled integer tensor: convert to float in the range [0, 1].
            image = image.float() / 255

        # Put the image tensor on the device the model weights are on.
        image = image.to(self.device)
        image_reconstructed, _ = self.forward(image)

        return image_reconstructed