summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--  src/text_recognizer/models/__init__.py            2
-rw-r--r--  src/text_recognizer/models/base.py               11
-rw-r--r--  src/text_recognizer/models/segmentation_model.py 75
-rw-r--r--  src/text_recognizer/models/transformer_model.py   4
4 files changed, 85 insertions, 7 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index bf89404..a645cec 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,11 +2,13 @@
from .base import Model
from .character_model import CharacterModel
from .crnn_model import CRNNModel
+from .segmentation_model import SegmentationModel
from .transformer_model import TransformerModel
__all__ = [
"CharacterModel",
"CRNNModel",
"Model",
+ "SegmentationModel",
"TransformerModel",
]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d394b4c..f2cd4b8 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -159,7 +159,7 @@ class Model(ABC):
self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
self.test_dataset.load_or_generate_data()
- # Set the flag to true to disable ability to load data agian.
+ # Set the flag to true to disable ability to load data again.
self.data_prepared = True
def train_dataloader(self) -> DataLoader:
@@ -260,7 +260,7 @@ class Model(ABC):
@property
def mapping(self) -> Dict:
"""Returns the mapping between network output and Emnist character."""
- return self._mapper.mapping
+ return self._mapper.mapping if self._mapper is not None else None
def eval(self) -> None:
"""Sets the network to evaluation mode."""
@@ -341,7 +341,7 @@ class Model(ABC):
if input_shape is not None:
summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
- input_shape = (1,) + tuple(self._input_shape)
+ input_shape = tuple(self._input_shape)
summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
@@ -427,7 +427,7 @@ class Model(ABC):
)
shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
- def load_weights(self, network_fn: Type[nn.Module]) -> None:
+ def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -441,7 +441,8 @@ class Model(ABC):
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- self._network = network_fn(**self._network_args)
+ if network_fn is not None:
+ self._network = network_fn(**self._network_args)
self._network.load_state_dict(weights)
if "swa_network" in state_dict:
diff --git a/src/text_recognizer/models/segmentation_model.py b/src/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/src/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+ """Model for segmenting lines in an image."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype is np.uint8:
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype is torch.uint8 or image.dtype is torch.int64:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ if not torch.is_tensor(image):
+ image = Tensor(image)
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ logits = self.forward(image)
+
+ segmentation_mask = torch.argmax(logits, dim=1)
+
+ return segmentation_mask
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 968a047..a912122 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -18,8 +18,8 @@ class TransformerModel(Model):
def __init__(
self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
+ network_fn: str,
+ dataset: str,
network_args: Optional[Dict] = None,
dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,