Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--  src/text_recognizer/models/__init__.py                    11
-rw-r--r--  src/text_recognizer/models/base.py                        55
-rw-r--r--  src/text_recognizer/models/character_model.py              1
-rw-r--r--  src/text_recognizer/models/line_ctc_model.py               8
-rw-r--r--  src/text_recognizer/models/vision_transformer_model.py   117
5 files changed, 175 insertions(+), 17 deletions(-)
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a3cfc15..0855079 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -3,5 +3,14 @@ from .base import Model
from .character_model import CharacterModel
from .line_ctc_model import LineCTCModel
from .metrics import accuracy, cer, wer
+from .vision_transformer_model import VisionTransformerModel
-__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"]
+__all__ = [
+ "Model",
+ "cer",
+ "CharacterModel",
+ "CNNTransfromerModel",
+ "LineCTCModel",
+ "accuracy",
+ "wer",
+]
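
With the export in place, the new model is importable straight from the package namespace:

    from text_recognizer.models import VisionTransformerModel
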
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index e89b670..cbef787 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -6,7 +6,7 @@ import importlib
from pathlib import Path
import re
import shutil
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from loguru import logger
import torch
@@ -15,6 +15,7 @@ from torch import Tensor
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
+from torchvision.transforms import Compose
from text_recognizer.datasets import EmnistMapper
@@ -128,16 +129,41 @@ class Model(ABC):
self._configure_criterion()
self._configure_optimizers()
- # Prints a summary of the network in terminal.
- self.summary()
-
# Set this flag to true to prevent the model from configuring again.
self.is_configured = True
+ def _configure_transforms(self) -> None:
+ # Load transforms.
+ transforms_module = importlib.import_module(
+ "text_recognizer.datasets.transforms"
+ )
+ if (
+ "transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["transform"] is not None
+ ):
+ transform_ = [
+ getattr(transforms_module, t["type"])()
+ for t in self.dataset_args["args"]["transform"]
+ ]
+ self.dataset_args["args"]["transform"] = Compose(transform_)
+ if (
+ "target_transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["target_transform"] is not None
+ ):
+ target_transform_ = [
+ torch.tensor,
+ ]
+ for t in self.dataset_args["args"]["target_transform"]:
+ args = t["args"] or {}
+ target_transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
+
def prepare_data(self) -> None:
"""Prepare data for training."""
# TODO add downloading.
if not self.data_prepared:
+ self._configure_transforms()
+
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
train_dataset.load_or_generate_data()
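
For reference, _configure_transforms reads transform specs out of dataset_args["args"]: each entry names a class in text_recognizer.datasets.transforms via "type", and target-transform entries additionally carry constructor kwargs via "args" (torch.tensor is prepended to the target pipeline). A hypothetical config sketch, assuming the class names below exist in that module:

    # Hypothetical dataset_args; the transform class names are illustrative only.
    dataset_args = {
        "args": {
            # Each "transform" entry is instantiated with no arguments.
            "transform": [{"type": "Transpose"}],
            # Each "target_transform" entry is instantiated with t["args"] or {}.
            "target_transform": [{"type": "Padding", "args": {"max_len": 128}}],
        }
    }
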
@@ -327,20 +353,20 @@ class Model(ABC):
else:
return self.network(x)
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
- """Compute the loss."""
- return self.criterion(output, targets)
-
def summary(
- self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 4,
+ device: Optional[str] = None,
) -> None:
"""Prints a summary of the network architecture."""
+ device = self.device if device is None else device
if input_shape is not None:
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
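
The summary method now takes an explicit device, letting the caller choose where torchsummary runs its dummy forward pass instead of always using the model's device. A hypothetical call (the input shape is illustrative):

    model.summary(input_shape=(1, 28, 952), depth=4, device="cpu")
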
@@ -364,18 +390,21 @@ class Model(ABC):
return state
- def load_from_checkpoint(self, checkpoint_path: Path) -> None:
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
"""Load a previously saved checkpoint.
Args:
checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
"""
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
logger.debug("Loading checkpoint...")
if not checkpoint_path.exists():
logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(checkpoint_path))
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
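
Since load_from_checkpoint now accepts plain strings, prepares the data and model itself, and maps tensors onto the model's device, restoring a GPU-trained run on a CPU-only machine reduces to a single call. A hypothetical usage (the path is illustrative):

    model.load_from_checkpoint("training/experiments/best_model.pt")
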
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 50e94a2..3cf6695 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -65,6 +65,7 @@ class CharacterModel(Model):
Tuple[str, float]: The predicted character and the confidence in the prediction.
"""
+ self.eval()
if image.dtype == np.uint8:
# Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
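
The added self.eval() call matters because layers such as Dropout and BatchNorm behave differently at inference time; without it, predictions would be stochastic. A minimal sketch of the effect:

    import torch
    from torch import nn

    layer = nn.Dropout(p=0.5)
    x = torch.ones(1, 4)

    layer.train()
    print(layer(x))  # random entries zeroed, survivors scaled by 2.0

    layer.eval()
    print(layer(x))  # identity mapping: all ones, deterministic output
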
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 16eaed3..cdc2d8b 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -51,7 +51,7 @@ class LineCTCModel(Model):
self._mapper = EmnistMapper()
self.tensor_transform = ToTensor()
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
"""Computes the CTC loss.
Args:
@@ -82,11 +82,13 @@ class LineCTCModel(Model):
torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
- return self.criterion(output, targets, input_lengths, target_lengths)
+ return self._criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
"""Predict on a single input."""
+ self.eval()
+
if image.dtype == np.uint8:
# Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
image = self.tensor_transform(image)
@@ -110,6 +112,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(log_probs.sum()).item()
+ confidence_of_prediction = torch.exp(-log_probs.sum()).item()
return predicted_characters, confidence_of_prediction
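
For context, the wrapped _criterion is expected to be nn.CTCLoss, which takes log-probabilities of shape (T, N, C) together with per-sample input and target lengths, exactly what the criterion method above assembles. A self-contained sketch, assuming blank index 0:

    import torch
    from torch import nn

    T, N, C = 50, 2, 80  # timesteps, batch size, classes incl. blank
    log_probs = torch.randn(T, N, C).log_softmax(dim=2)
    targets = torch.randint(1, C, (N, 10), dtype=torch.long)  # no blanks in targets
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.full((N,), 10, dtype=torch.long)

    ctc = nn.CTCLoss(blank=0)
    loss = ctc(log_probs, targets, input_lengths, target_lengths)
    print(loss.item())
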
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
new file mode 100644
index 0000000..20bd4ca
--- /dev/null
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -0,0 +1,117 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class VisionTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 128
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[str, float]:
+ src = self.network.preprocess_input(image)
+ memory = self.network.encoder(src)
+
+ confidence_of_predictions = []
+ trg_indices = [self.mapper(self.init_token)]
+
+ for _ in range(self.max_len):
+ trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
+ trg, trg_mask = self.network.preprocess_target(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+
+ # Convert logits to probabilities.
+ probs = self.softmax(logits)
+
+ pred_token = probs.argmax(2)[:, -1].item()
+ confidence = probs.max(2).values[:, -1].item()
+
+ trg_indices.append(pred_token)
+ confidence_of_predictions.append(confidence)
+
+ if pred_token == self.mapper(self.eos_token):
+ break
+
+ confidence = np.min(confidence_of_predictions)
+ predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+
+ return predicted_characters, confidence_of_prediction
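
Putting the new model together end to end might look like the following; the network and dataset classes, constructor arguments, and checkpoint path are all assumptions for illustration, not part of this diff:

    import numpy as np
    from text_recognizer.models import VisionTransformerModel

    # Hypothetical classes and arguments; only predict_on_image's contract is from the diff.
    model = VisionTransformerModel(
        network_fn=MyCNNTransformer,  # assumed network class
        dataset=MyLinesDataset,       # assumed dataset class
        network_args={"max_len": 128},
        dataset_args={
            "args": {"init_token": "<sos>", "pad_token": "<pad>", "eos_token": "<eos>"}
        },
    )
    model.load_from_checkpoint("training/experiments/best_model.pt")  # assumed path

    image = np.zeros((28, 952), dtype=np.uint8)  # grayscale line image, illustrative shape
    text, confidence = model.predict_on_image(image)
    print(text, confidence)
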