Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--   src/text_recognizer/models/__init__.py            |  2
-rw-r--r--   src/text_recognizer/models/base.py                 | 11
-rw-r--r--   src/text_recognizer/models/segmentation_model.py   | 75
-rw-r--r--   src/text_recognizer/models/transformer_model.py    |  4
4 files changed, 85 insertions, 7 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index bf89404..a645cec 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,11 +2,13 @@
 from .base import Model
 from .character_model import CharacterModel
 from .crnn_model import CRNNModel
+from .segmentation_model import SegmentationModel
 from .transformer_model import TransformerModel
 
 __all__ = [
     "CharacterModel",
     "CRNNModel",
     "Model",
+    "SegmentationModel",
     "TransformerModel",
 ]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d394b4c..f2cd4b8 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -159,7 +159,7 @@ class Model(ABC):
         self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
         self.test_dataset.load_or_generate_data()
 
-        # Set the flag to true to disable ability to load data agian.
+        # Set the flag to true to disable ability to load data again.
         self.data_prepared = True
 
     def train_dataloader(self) -> DataLoader:
@@ -260,7 +260,7 @@ class Model(ABC):
     @property
     def mapping(self) -> Dict:
         """Returns the mapping between network output and Emnist character."""
-        return self._mapper.mapping
+        return self._mapper.mapping if self._mapper is not None else None
 
     def eval(self) -> None:
         """Sets the network to evaluation mode."""
@@ -341,7 +341,7 @@ class Model(ABC):
         if input_shape is not None:
             summary(self.network, input_shape, depth=depth, device=device)
         elif self._input_shape is not None:
-            input_shape = (1,) + tuple(self._input_shape)
+            input_shape = tuple(self._input_shape)
             summary(self.network, input_shape, depth=depth, device=device)
         else:
             logger.warning("Could not print summary as input shape is not set.")
@@ -427,7 +427,7 @@ class Model(ABC):
             )
             shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
 
-    def load_weights(self, network_fn: Type[nn.Module]) -> None:
+    def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
         """Load the network weights."""
         logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
@@ -441,7 +441,8 @@ class Model(ABC):
         weights = state_dict["model_state"]
 
         # Initializes the network with trained weights.
-        self._network = network_fn(**self._network_args)
+        if network_fn is not None:
+            self._network = network_fn(**self._network_args)
         self._network.load_state_dict(weights)
 
         if "swa_network" in state_dict:
diff --git a/src/text_recognizer/models/segmentation_model.py b/src/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/src/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+    """Model for segmenting lines in an image."""
+
+    def __init__(
+        self,
+        network_fn: str,
+        dataset: str,
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.tensor_transform = ToTensor()
+        self.softmax = nn.Softmax(dim=2)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype is np.uint8:
+            # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype is torch.uint8 or image.dtype is torch.int64:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        if not torch.is_tensor(image):
+            image = Tensor(image)
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+
+        logits = self.forward(image)
+
+        segmentation_mask = torch.argmax(logits, dim=1)
+
+        return segmentation_mask
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 968a047..a912122 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -18,8 +18,8 @@ class TransformerModel(Model):
 
     def __init__(
         self,
-        network_fn: Type[nn.Module],
-        dataset: Type[Dataset],
+        network_fn: str,
+        dataset: str,
         network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
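
For context, a minimal usage sketch of the new SegmentationModel added in this change. The network name, dataset name, and network arguments below are assumptions for illustration only; they do not come from this diff. It also shows that load_weights() can now be called without a network_fn, reusing the network that was already built from the config strings.

# Hypothetical usage sketch; "UNet", "IamParagraphsDataset" and the
# network_args values are assumed names, not taken from this diff.
import numpy as np

from text_recognizer.models import SegmentationModel

model = SegmentationModel(
    network_fn="UNet",  # assumed network name
    dataset="IamParagraphsDataset",  # assumed dataset name
    network_args={"in_channels": 1, "num_classes": 4},  # assumed arguments
)
model.load_weights()  # network_fn is optional after this change

# Predict a per-pixel class mask for a grayscale page crop.
image = np.zeros((256, 256), dtype=np.uint8)
mask = model.predict_on_image(image)  # argmax over the network logits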