summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/segmentation_model.py
blob: 613108a8143f9d0b03362d48f5b66dc0d2f36bcb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Segmentation model for detecting and segmenting lines."""
from typing import Callable, Dict, Optional, Type, Union

import numpy as np
import torch
from torch import nn
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

from text_recognizer.models.base import Model


class SegmentationModel(Model):
    """Model for segmenting lines in an image."""

    def __init__(
        self,
        network_fn: str,
        dataset: str,
        network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        swa_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initialize the model; every argument is forwarded to ``Model``.

        The arguments are passed positionally to the base class, so their
        order here must match ``Model.__init__``.
        """
        super().__init__(
            network_fn,
            dataset,
            network_args,
            dataset_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            swa_args,
            device,
        )
        # Converts uint8 images with range [0, 255] to float tensors in [0, 1].
        self.tensor_transform = ToTensor()
        # NOTE(review): not used inside this class; presumably consumed by
        # callers or subclasses — verify before removing.
        self.softmax = nn.Softmax(dim=2)

    @torch.no_grad()
    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
        """Predict a segmentation mask for a single input image.

        Args:
            image: The input image, as a NumPy array or a PyTorch tensor.
                ``uint8`` NumPy arrays are converted with ``ToTensor``, which
                scales them to [0, 1]. ``uint8``/``int64`` tensors are
                converted to float and divided by 255. Any other non-tensor
                input is wrapped with ``Tensor`` as-is.

        Returns:
            Tensor: The argmax of the network logits over dimension 1.
        """
        self.eval()

        # Compare dtypes with ``==``: ``ndarray.dtype`` is a ``np.dtype``
        # instance, so an identity check against ``np.uint8`` never matches.
        if isinstance(image, np.ndarray) and image.dtype == np.uint8:
            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
            image = self.tensor_transform(image)

        # Rescale unscaled integer tensors to the range [0, 1].
        if torch.is_tensor(image) and image.dtype in (torch.uint8, torch.int64):
            image = image.float() / 255

        if not torch.is_tensor(image):
            image = Tensor(image)

        # Put the image tensor on the device the model weights are on.
        image = image.to(self.device)

        logits = self.forward(image)

        # NOTE(review): no batch dimension is added here; dim=1 is assumed to
        # be the class axis of the logits (e.g. (batch, classes, H, W)) — confirm.
        segmentation_mask = torch.argmax(logits, dim=1)

        return segmentation_mask