From 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Mon, 7 Dec 2020 22:54:04 +0100
Subject: Segmentation working!

---
 src/text_recognizer/models/__init__.py           |  2 +
 src/text_recognizer/models/base.py               | 11 ++--
 src/text_recognizer/models/segmentation_model.py | 75 ++++++++++++++++++++++++
 src/text_recognizer/models/transformer_model.py  |  4 +-
 4 files changed, 85 insertions(+), 7 deletions(-)
 create mode 100644 src/text_recognizer/models/segmentation_model.py

(limited to 'src/text_recognizer/models')

diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index bf89404..a645cec 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,11 +2,13 @@
 from .base import Model
 from .character_model import CharacterModel
 from .crnn_model import CRNNModel
+from .segmentation_model import SegmentationModel
 from .transformer_model import TransformerModel
 
 __all__ = [
     "CharacterModel",
     "CRNNModel",
     "Model",
+    "SegmentationModel",
     "TransformerModel",
 ]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d394b4c..f2cd4b8 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -159,7 +159,7 @@ class Model(ABC):
             self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
             self.test_dataset.load_or_generate_data()
 
-            # Set the flag to true to disable ability to load data agian.
+            # Set the flag to true to disable ability to load data again.
             self.data_prepared = True
 
     def train_dataloader(self) -> DataLoader:
@@ -260,7 +260,7 @@ class Model(ABC):
     @property
     def mapping(self) -> Dict:
         """Returns the mapping between network output and Emnist character."""
-        return self._mapper.mapping
+        return self._mapper.mapping if self._mapper is not None else None
 
     def eval(self) -> None:
         """Sets the network to evaluation mode."""
@@ -341,7 +341,7 @@ class Model(ABC):
         if input_shape is not None:
             summary(self.network, input_shape, depth=depth, device=device)
         elif self._input_shape is not None:
-            input_shape = (1,) + tuple(self._input_shape)
+            input_shape = tuple(self._input_shape)
             summary(self.network, input_shape, depth=depth, device=device)
         else:
             logger.warning("Could not print summary as input shape is not set.")
@@ -427,7 +427,7 @@ class Model(ABC):
             )
             shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
 
-    def load_weights(self, network_fn: Type[nn.Module]) -> None:
+    def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
         """Load the network weights."""
         logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
@@ -441,7 +441,8 @@ class Model(ABC):
         weights = state_dict["model_state"]
 
         # Initializes the network with trained weights.
-        self._network = network_fn(**self._network_args)
+        if network_fn is not None:
+            self._network = network_fn(**self._network_args)
         self._network.load_state_dict(weights)
 
         if "swa_network" in state_dict:
diff --git a/src/text_recognizer/models/segmentation_model.py b/src/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/src/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+    """Model for segmenting lines in an image."""
+
+    def __init__(
+        self,
+        network_fn: str,
+        dataset: str,
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.tensor_transform = ToTensor()
+        self.softmax = nn.Softmax(dim=2)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype is np.uint8:
+            # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype is torch.uint8 or image.dtype is torch.int64:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        if not torch.is_tensor(image):
+            image = Tensor(image)
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+
+        logits = self.forward(image)
+
+        segmentation_mask = torch.argmax(logits, dim=1)
+
+        return segmentation_mask
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 968a047..a912122 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -18,8 +18,8 @@ class TransformerModel(Model):
 
     def __init__(
         self,
-        network_fn: Type[nn.Module],
-        dataset: Type[Dataset],
+        network_fn: str,
+        dataset: str,
         network_args: Optional[Dict] = None,
         dataset_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
-- 
cgit v1.2.3-70-g09d2