Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--  src/text_recognizer/models/__init__.py          |  2 ++
-rw-r--r--  src/text_recognizer/models/base.py              |  2 +-
-rw-r--r--  src/text_recognizer/models/transformer_model.py | 12 +++++++++---
-rw-r--r--  src/text_recognizer/models/vqvae_model.py       | 80 ++++++++++++++++++++++
4 files changed, 92 insertions(+), 4 deletions(-)
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index eb5dbce..7647d7e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -5,6 +5,7 @@ from .crnn_model import CRNNModel
 from .ctc_transformer_model import CTCTransformerModel
 from .segmentation_model import SegmentationModel
 from .transformer_model import TransformerModel
+from .vqvae_model import VQVAEModel
 
 __all__ = [
     "CharacterModel",
@@ -13,4 +14,5 @@ __all__ = [
     "Model",
     "SegmentationModel",
     "TransformerModel",
+    "VQVAEModel",
 ]
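
With the re-export above in place, the new model is importable from the package root alongside the existing ones; a minimal sketch:

# VQVAEModel now sits next to the other model classes in the public API.
from text_recognizer.models import TransformerModel, VQVAEModel
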
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index f2cd4b8..70f4cdb 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -332,7 +332,7 @@ class Model(ABC):
     def summary(
         self,
         input_shape: Optional[Union[List, Tuple]] = None,
-        depth: int = 4,
+        depth: int = 3,
         device: Optional[str] = None,
     ) -> None:
         """Prints a summary of the network architecture."""
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 12e497f..3f63053 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -6,9 +6,9 @@ import torch
 from torch import nn
 from torch import Tensor
 from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
 
 from text_recognizer.datasets import EmnistMapper
+import text_recognizer.datasets.transforms as transforms
 from text_recognizer.models.base import Model
 from text_recognizer.networks import greedy_decoder
 
@@ -60,13 +60,19 @@ class TransformerModel(Model):
             eos_token=self.eos_token,
             lower=self.lower,
         )
-        self.tensor_transform = ToTensor()
-
+        self.tensor_transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
+        )
         self.softmax = nn.Softmax(dim=2)
 
     @torch.no_grad()
     def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
         src = self.network.extract_image_features(image)
+
+        # Added for the VQ-VAE transformer: keep only the image features.
+        if isinstance(src, tuple):
+            src = src[0]
+
         memory = self.network.encoder(src)
 
         confidence_of_predictions = []
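
The replacement transform zero-centers inputs with a fixed dataset mean and standard deviation before feature extraction, and the new tuple check unwraps backbones (such as the VQ-VAE) whose `extract_image_features` returns more than the feature map. A minimal sketch of the normalization using the torchvision equivalents (the project's own `text_recognizer.datasets.transforms` module is assumed to mirror this API):

import numpy as np
from torchvision.transforms import Compose, Normalize, ToTensor

transform = Compose([ToTensor(), Normalize(mean=[0.912], std=[0.168])])

# A hypothetical grayscale line image, H x W, uint8 in [0, 255].
image = np.random.randint(0, 256, (28, 952), dtype=np.uint8)
tensor = transform(image)  # shape (1, 28, 952), roughly zero-centered
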
diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py
new file mode 100644
index 0000000..70f6f1f
--- /dev/null
+++ b/src/text_recognizer/models/vqvae_model.py
@@ -0,0 +1,80 @@
+"""Defines the VQVAEModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class VQVAEModel(Model):
+ """Model for reconstructing images from codebook."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        """Reconstructs an image via the codebook.
+
+        Args:
+            image (Union[np.ndarray, torch.Tensor]): An image containing a character.
+
+        Returns:
+            torch.Tensor: The reconstructed image.
+
+        """
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image in [0, 255] to a PyTorch tensor in [0, 1].
+            image = self.tensor_transform(image)
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.float() / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+        image_reconstructed, _ = self.forward(image)
+
+        return image_reconstructed
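
The two dtype branches in `predict_on_image` accept both uint8 numpy arrays and unscaled uint8 tensors. A standalone sketch of that conversion logic (the helper name is illustrative, not part of the module):

import numpy as np
import torch
from torchvision.transforms import ToTensor

def to_unit_float(image):
    """Scales a uint8 image (numpy or torch) to a float tensor in [0, 1]."""
    if isinstance(image, np.ndarray) and image.dtype == np.uint8:
        return ToTensor()(image)  # HW or HWC uint8 -> CHW float in [0, 1]
    if isinstance(image, torch.Tensor) and image.dtype == torch.uint8:
        return image.float() / 255
    return image

img = np.random.randint(0, 256, (28, 28), dtype=np.uint8)
print(to_unit_float(img).shape)  # torch.Size([1, 28, 28])
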