1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
|
"""Segmentation model for detecting and segmenting lines."""
from typing import Callable, Dict, Optional, Type, Union
import numpy as np
import torch
from torch import nn
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from text_recognizer.models.base import Model
class SegmentationModel(Model):
    """Model for segmenting lines in an image."""

    def __init__(
        self,
        network_fn: str,
        dataset: str,
        network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        swa_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Forward the configuration to ``Model`` and set up inference helpers.

        Every argument is passed positionally, unchanged, to
        ``Model.__init__``; see that class for what each one means.
        """
        super().__init__(
            network_fn,
            dataset,
            network_args,
            dataset_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            swa_args,
            device,
        )
        # Converts uint8 numpy images to float tensors scaled to [0, 1].
        self.tensor_transform = ToTensor()
        # NOTE(review): not used by predict_on_image below; kept because other
        # code may rely on this attribute -- confirm before removing.
        self.softmax = nn.Softmax(dim=2)

    @torch.no_grad()
    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
        """Predict a segmentation mask for a single input.

        Args:
            image: The input image, as a numpy array or a tensor. It is
                normalised as follows:

                * ``uint8`` numpy array: converted with ``ToTensor`` (values
                  scaled from [0, 255] to [0, 1]).
                * ``uint8`` / ``int64`` tensor: cast to float and divided by
                  255.
                * any other numpy array: wrapped with ``Tensor(...)``, with no
                  rescaling.
                * any other tensor: used as-is.

        Returns:
            ``argmax`` of the network logits over dim 1, i.e. the predicted
            class index per position. This assumes dim 1 of the logits is the
            class axis -- TODO confirm against the network's output layout.
        """
        self.eval()

        if isinstance(image, np.ndarray) and image.dtype == np.uint8:
            # `image.dtype` is an np.dtype instance, never identical to the
            # np.uint8 scalar type, so this must be an equality check (an
            # `is` check silently skipped the [0, 255] -> [0, 1] conversion).
            image = self.tensor_transform(image)
        elif torch.is_tensor(image) and image.dtype in (torch.uint8, torch.int64):
            # Unscaled integer tensor: rescale to [0, 1].
            image = image.float() / 255
        elif not torch.is_tensor(image):
            image = Tensor(image)

        # Put the image tensor on the device the model weights are on.
        image = image.to(self.device)
        logits = self.forward(image)
        segmentation_mask = torch.argmax(logits, dim=1)
        return segmentation_mask
|