diff options
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/models/transformer_model.py | 12 | ||||
-rw-r--r-- | src/text_recognizer/models/vqvae_model.py | 80 |
4 files changed, 92 insertions, 4 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index eb5dbce..7647d7e 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -5,6 +5,7 @@ from .crnn_model import CRNNModel from .ctc_transformer_model import CTCTransformerModel from .segmentation_model import SegmentationModel from .transformer_model import TransformerModel +from .vqvae_model import VQVAEModel __all__ = [ "CharacterModel", @@ -13,4 +14,5 @@ __all__ = [ "Model", "SegmentationModel", "TransformerModel", + "VQVAEModel", ] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index f2cd4b8..70f4cdb 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -332,7 +332,7 @@ class Model(ABC): def summary( self, input_shape: Optional[Union[List, Tuple]] = None, - depth: int = 4, + depth: int = 3, device: Optional[str] = None, ) -> None: """Prints a summary of the network architecture.""" diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py index 12e497f..3f63053 100644 --- a/src/text_recognizer/models/transformer_model.py +++ b/src/text_recognizer/models/transformer_model.py @@ -6,9 +6,9 @@ import torch from torch import nn from torch import Tensor from torch.utils.data import Dataset -from torchvision.transforms import ToTensor from text_recognizer.datasets import EmnistMapper +import text_recognizer.datasets.transforms as transforms from text_recognizer.models.base import Model from text_recognizer.networks import greedy_decoder @@ -60,13 +60,19 @@ class TransformerModel(Model): eos_token=self.eos_token, lower=self.lower, ) - self.tensor_transform = ToTensor() - + self.tensor_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] + ) self.softmax = nn.Softmax(dim=2) @torch.no_grad() def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: src = self.network.extract_image_features(image) + + # Added for vqvae transformer. + if isinstance(src, Tuple): + src = src[0] + memory = self.network.encoder(src) confidence_of_predictions = [] diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py new file mode 100644 index 0000000..70f6f1f --- /dev/null +++ b/src/text_recognizer/models/vqvae_model.py @@ -0,0 +1,80 @@ +"""Defines the VQVAEModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class VQVAEModel(Model): + """Model for reconstructing images from codebook.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=0) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + """Reconstruction of image. + + Args: + image (Union[np.ndarray, torch.Tensor]): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. + + """ + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + image_reconstructed, _ = self.forward(image) + + return image_reconstructed |