summaryrefslogtreecommitdiff
path: root/text_recognizer/models/segmentation_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/segmentation_model.py')
-rw-r--r--text_recognizer/models/segmentation_model.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+ """Model for segmenting lines in an image."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype is np.uint8:
+ # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype is torch.uint8 or image.dtype is torch.int64:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ if not torch.is_tensor(image):
+ image = Tensor(image)
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ logits = self.forward(image)
+
+ segmentation_mask = torch.argmax(logits, dim=1)
+
+ return segmentation_mask