Diffstat (limited to 'src/text_recognizer/models')
 src/text_recognizer/models/__init__.py                                                    |  16
 src/text_recognizer/models/base.py                                                        |  66
 src/text_recognizer/models/character_model.py                                             |   4
 src/text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/line_ctc_model.py) |  16
 src/text_recognizer/models/metrics.py                                                     |   5
 src/text_recognizer/models/transformer_encoder_model.py                                   | 111
 src/text_recognizer/models/vision_transformer_model.py                                    | 119
 7 files changed, 311 insertions(+), 26 deletions(-)
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a3cfc15..28aa52e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,7 +1,19 @@
"""Model modules."""
from .base import Model
from .character_model import CharacterModel
-from .line_ctc_model import LineCTCModel
+from .crnn_model import CRNNModel
from .metrics import accuracy, cer, wer
+from .transformer_encoder_model import TransformerEncoderModel
+from .vision_transformer_model import VisionTransformerModel
-__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"]
+__all__ = [
+ "Model",
+ "cer",
+ "CharacterModel",
+ "CRNNModel",
+ "CNNTransfromerModel",
+ "accuracy",
+ "TransformerEncoderModel",
+ "VisionTransformerModel",
+ "wer",
+]
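
With the re-exports above, downstream code imports the renamed model from the package root; a minimal sketch of the new public API (assuming the package is installed):

    from text_recognizer.models import (
        CRNNModel,
        TransformerEncoderModel,
        VisionTransformerModel,
        accuracy,
        cer,
        wer,
    )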
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index caf8065..cc44c92 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -6,7 +6,7 @@ import importlib
from pathlib import Path
import re
import shutil
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from loguru import logger
import torch
@@ -15,6 +15,7 @@ from torch import Tensor
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
+from torchvision.transforms import Compose
from text_recognizer.datasets import EmnistMapper
@@ -128,16 +129,42 @@ class Model(ABC):
self._configure_criterion()
self._configure_optimizers()
- # Prints a summary of the network in terminal.
- self.summary()
-
# Set this flag to true to prevent the model from configuring again.
self.is_configured = True
+ def _configure_transforms(self) -> None:
+ # Load transforms.
+ transforms_module = importlib.import_module(
+ "text_recognizer.datasets.transforms"
+ )
+ if (
+ "transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["transform"] is not None
+ ):
+ transform_ = []
+ for t in self.dataset_args["args"]["transform"]:
+ args = t["args"] or {}
+ transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["transform"] = Compose(transform_)
+
+ if (
+ "target_transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["target_transform"] is not None
+ ):
+ target_transform_ = [
+ torch.tensor,
+ ]
+ for t in self.dataset_args["args"]["target_transform"]:
+ args = t["args"] or {}
+ target_transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
+
def prepare_data(self) -> None:
"""Prepare data for training."""
# TODO add downloading.
if not self.data_prepared:
+ self._configure_transforms()
+
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
train_dataset.load_or_generate_data()
@@ -327,20 +354,20 @@ class Model(ABC):
else:
return self.network(x)
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
- """Compute the loss."""
- return self.criterion(output, targets)
-
def summary(
- self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 4,
+ device: Optional[str] = None,
) -> None:
"""Prints a summary of the network architecture."""
+ device = self.device if device is None else device
if input_shape is not None:
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
@@ -356,25 +383,29 @@ class Model(ABC):
state["optimizer_state"] = self._optimizer.state_dict()
if self._lr_scheduler is not None:
- state["scheduler_state"] = self._lr_scheduler.state_dict()
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
if self._swa_network is not None:
state["swa_network"] = self._swa_network.state_dict()
return state
- def load_from_checkpoint(self, checkpoint_path: Path) -> None:
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
"""Load a previously saved checkpoint.
Args:
checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
"""
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
logger.debug("Loading checkpoint...")
if not checkpoint_path.exists():
logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(checkpoint_path))
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
@@ -383,8 +414,11 @@ class Model(ABC):
if self._lr_scheduler is not None:
# Does not work when loading from a previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
- if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
- self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
if self._swa_network is not None:
self._swa_network.load_state_dict(checkpoint["swa_network"])
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 50e94a2..f9944f3 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -47,8 +47,9 @@ class CharacterModel(Model):
swa_args,
device,
)
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+            self._mapper = EmnistMapper(pad_token=self.pad_token)
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
@@ -65,6 +66,7 @@ class CharacterModel(Model):
Tuple[str, float]: The predicted character and the confidence in the prediction.
"""
+ self.eval()
if image.dtype == np.uint8:
# Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/crnn_model.py
index 16eaed3..1e01a83 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/crnn_model.py
@@ -1,4 +1,4 @@
-"""Defines the LineCTCModel class."""
+"""Defines the CRNNModel class."""
from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
@@ -13,7 +13,7 @@ from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
-class LineCTCModel(Model):
+class CRNNModel(Model):
"""Model for predicting a sequence of characters from an image of a text line."""
def __init__(
@@ -47,11 +47,13 @@ class LineCTCModel(Model):
swa_args,
device,
)
+
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+            self._mapper = EmnistMapper(pad_token=self.pad_token)
self.tensor_transform = ToTensor()
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
"""Computes the CTC loss.
Args:
@@ -82,11 +84,13 @@ class LineCTCModel(Model):
torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
- return self.criterion(output, targets, input_lengths, target_lengths)
+ return self._criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
"""Predict on a single input."""
+ self.eval()
+
if image.dtype == np.uint8:
# Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
image = self.tensor_transform(image)
@@ -110,6 +114,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(log_probs.sum()).item()
+        confidence_of_prediction = torch.exp(log_probs.cumsum(dim=0))[-1].item()
return predicted_characters, confidence_of_prediction
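
As a sanity check on the confidence change above: the values are per-step log-probabilities, so the sequence confidence is the product of the probabilities, i.e. the exponential of their cumulative sum; a minimal sketch with made-up numbers:

    import torch

    # Per-step maximum log-probabilities of a three-character prediction.
    log_probs = torch.log(torch.tensor([0.9, 0.8, 0.95]))
    confidence = torch.exp(log_probs.cumsum(dim=0))[-1].item()
    print(round(confidence, 3))  # 0.684 == 0.9 * 0.8 * 0.95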
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 6a26216..42c3c6e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
float: The accuracy for the batch.
"""
- _, predicted = torch.max(outputs.data, dim=1)
+ # eos_index = torch.nonzero(labels == eos, as_tuple=False)
+ # eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+ _, predicted = torch.max(outputs, dim=-1)
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
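
A quick check of the updated accuracy path with a toy batch (shapes and values are illustrative):

    import torch

    outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # (batch, num_classes)
    labels = torch.tensor([1, 0, 0])
    _, predicted = torch.max(outputs, dim=-1)  # argmax over the class dimension
    acc = ((predicted == labels).sum().float() / labels.shape[0]).item()
    print(acc)  # 0.666...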
diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py
new file mode 100644
index 0000000..e35e298
--- /dev/null
+++ b/src/text_recognizer/models/transformer_encoder_model.py
@@ -0,0 +1,111 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class TransformerEncoderModel(Model):
+ """A class for only using the encoder part in the sequence modelling."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ # self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 128
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ # init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+    def _generate_sentence(self, image: Tensor) -> Tuple[str, float]:
+ logits = self.network(image)
+ # Convert logits to probabilities.
+ probs = self.softmax(logits).squeeze(0)
+
+        confidence, pred_tokens = probs.max(1)
+
+ eos_index = torch.nonzero(
+ pred_tokens == self._mapper(self.eos_token), as_tuple=False,
+ )
+
+        eos_index = eos_index[0].item() if eos_index.nelement() else pred_tokens.shape[0]
+
+ predicted_characters = "".join(
+ [self.mapper(x) for x in pred_tokens[:eos_index].tolist()]
+ )
+
+ confidence = np.min(confidence.tolist())
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+        predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+
+ return predicted_characters, confidence_of_prediction
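
The eos handling in _generate_sentence truncates the prediction at the first end-of-sequence token; a minimal sketch of that slicing with made-up token ids:

    import torch

    eos = 3
    pred_tokens = torch.tensor([5, 12, 7, 3, 0, 0])
    eos_index = torch.nonzero(pred_tokens == eos, as_tuple=False)
    eos_index = eos_index[0].item() if eos_index.nelement() else pred_tokens.shape[0]
    print(pred_tokens[:eos_index].tolist())  # [5, 12, 7]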
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
new file mode 100644
index 0000000..3d36437
--- /dev/null
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -0,0 +1,119 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class VisionTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 120
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+    def _generate_sentence(self, image: Tensor) -> Tuple[str, float]:
+ src = self.network.preprocess_input(image)
+ memory = self.network.encoder(src)
+
+ confidence_of_predictions = []
+ trg_indices = [self.mapper(self.init_token)]
+
+ for _ in range(self.max_len - 1):
+ trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
+ trg = self.network.preprocess_target(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
+
+ # Convert logits to probabilities.
+ probs = self.softmax(logits)
+
+ pred_token = probs.argmax(2)[:, -1].item()
+ confidence = probs.max(2).values[:, -1].item()
+
+ trg_indices.append(pred_token)
+ confidence_of_predictions.append(confidence)
+
+ if pred_token == self.mapper(self.eos_token):
+ break
+
+ confidence = np.min(confidence_of_predictions)
+ predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+        predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+
+ return predicted_characters, confidence_of_prediction
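
A hypothetical end-to-end use of predict_on_image, assuming a configured VisionTransformerModel with restored weights (the variable names and image shape are illustrative):

    import numpy as np

    # `model` is assumed to be built with the usual network/dataset args and
    # restored via model.load_from_checkpoint(...).
    line_image = np.zeros((28, 952), dtype=np.uint8)  # grayscale image of a text line
    text, confidence = model.predict_on_image(line_image)
    print(text, confidence)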