summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/__init__.py2
-rw-r--r--src/text_recognizer/models/ctc_transformer_model.py120
-rw-r--r--src/text_recognizer/models/transformer_model.py4
3 files changed, 125 insertions, 1 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a645cec..eb5dbce 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,12 +2,14 @@
from .base import Model
from .character_model import CharacterModel
from .crnn_model import CRNNModel
+from .ctc_transformer_model import CTCTransformerModel
from .segmentation_model import SegmentationModel
from .transformer_model import TransformerModel
__all__ = [
"CharacterModel",
"CRNNModel",
+ "CTCTransformerModel",
"Model",
"SegmentationModel",
"TransformerModel",
diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/src/text_recognizer/models/ctc_transformer_model.py
new file mode 100644
index 0000000..25925f2
--- /dev/null
+++ b/src/text_recognizer/models/ctc_transformer_model.py
@@ -0,0 +1,120 @@
+"""Defines the CTC Transformer Model class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CTCTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.lower = dataset_args["args"]["lower"]
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
+
+ self.tensor_transform = ToTensor()
+
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+ # Input lengths on the form [T, B]
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+
+ # Configure target tensors for ctc loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 53]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+ )
+
+ return self._criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = self.forward(image)
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=53,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+ return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index a912122..12e497f 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -50,13 +50,15 @@ class TransformerModel(Model):
self.init_token = dataset_args["args"]["init_token"]
self.pad_token = dataset_args["args"]["pad_token"]
self.eos_token = dataset_args["args"]["eos_token"]
- self.max_len = 120
+ self.lower = dataset_args["args"]["lower"]
+ self.max_len = 100
if self._mapper is None:
self._mapper = EmnistMapper(
init_token=self.init_token,
pad_token=self.pad_token,
eos_token=self.eos_token,
+ lower=self.lower,
)
self.tensor_transform = ToTensor()