summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/models
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/__init__.py18
-rw-r--r--text_recognizer/models/base.py455
-rw-r--r--text_recognizer/models/character_model.py88
-rw-r--r--text_recognizer/models/crnn_model.py119
-rw-r--r--text_recognizer/models/ctc_transformer_model.py120
-rw-r--r--text_recognizer/models/segmentation_model.py75
-rw-r--r--text_recognizer/models/transformer_model.py124
-rw-r--r--text_recognizer/models/vqvae_model.py80
8 files changed, 1079 insertions, 0 deletions
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
new file mode 100644
index 0000000..7647d7e
--- /dev/null
+++ b/text_recognizer/models/__init__.py
@@ -0,0 +1,18 @@
+"""Model modules."""
+from .base import Model
+from .character_model import CharacterModel
+from .crnn_model import CRNNModel
+from .ctc_transformer_model import CTCTransformerModel
+from .segmentation_model import SegmentationModel
+from .transformer_model import TransformerModel
+from .vqvae_model import VQVAEModel
+
+__all__ = [
+ "CharacterModel",
+ "CRNNModel",
+ "CTCTransformerModel",
+ "Model",
+ "SegmentationModel",
+ "TransformerModel",
+ "VQVAEModel",
+]
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
new file mode 100644
index 0000000..70f4cdb
--- /dev/null
+++ b/text_recognizer/models/base.py
@@ -0,0 +1,455 @@
+"""Abstract Model class for PyTorch neural networks."""
+
+from abc import ABC, abstractmethod
+from glob import glob
+import importlib
+from pathlib import Path
+import re
+import shutil
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, Dataset, random_split
+from torchsummary import summary
+
+from text_recognizer import datasets
+from text_recognizer import networks
+from text_recognizer.datasets import EmnistMapper
+
+WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
+
+
+class Model(ABC):
+ """Abstract Model class with composition of different parts defining a PyTorch neural network."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Base class, to be inherited by model for specific type of data.
+
+ Args:
+ network_fn (str): The name of network.
+ dataset (str): The name dataset class.
+ network_args (Optional[Dict]): Arguments for the network. Defaults to None.
+ dataset_args (Optional[Dict]): Arguments for the dataset.
+ metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
+ Defaults to None.
+ criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
+ optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
+ optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
+ lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
+ lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
+ None.
+ swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+ None.
+ device (Optional[str]): Name of the device to train on. Defaults to None.
+
+ """
+ self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
+ # Has to be set in subclass.
+ self._mapper = None
+
+ # Placeholder.
+ self._input_shape = None
+
+ self.dataset_name = dataset
+ self.dataset = None
+ self.dataset_args = dataset_args
+
+ # Placeholders for datasets.
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+
+ # Stochastic Weight Averaging placeholders.
+ self.swa_args = swa_args
+ self._swa_scheduler = None
+ self._swa_network = None
+ self._use_swa_model = False
+
+ # Experiment directory.
+ self.model_dir = None
+
+ # Flag for configured model.
+ self.is_configured = False
+ self.data_prepared = False
+
+ # Flag for stopping training.
+ self.stop_training = False
+
+ self._metrics = metrics if metrics is not None else None
+
+ # Set the device.
+ self._device = (
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if device is None
+ else device
+ )
+
+ # Configure network.
+ self._network = None
+ self._network_args = network_args
+ self._configure_network(network_fn)
+
+ # Place network on device (GPU).
+ self.to_device()
+
+ # Loss and Optimizer placeholders for before loading.
+ self._criterion = criterion
+ self.criterion_args = criterion_args
+
+ self._optimizer = optimizer
+ self.optimizer_args = optimizer_args
+
+ self._lr_scheduler = lr_scheduler
+ self.lr_scheduler_args = lr_scheduler_args
+
+ def configure_model(self) -> None:
+ """Configures criterion and optimizers."""
+ if not self.is_configured:
+ self._configure_criterion()
+ self._configure_optimizers()
+
+ # Set this flag to true to prevent the model from configuring again.
+ self.is_configured = True
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ # TODO add downloading.
+ if not self.data_prepared:
+ # Load dataset module.
+ self.dataset = getattr(datasets, self.dataset_name)
+
+ # Load train dataset.
+ train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+ train_dataset.load_or_generate_data()
+
+ # Set input shape.
+ self._input_shape = train_dataset.input_shape
+
+ # Split train dataset into a training and validation partition.
+ dataset_len = len(train_dataset)
+ train_len = int(
+ self.dataset_args["train_args"]["train_fraction"] * dataset_len
+ )
+ val_len = dataset_len - train_len
+ self.train_dataset, self.val_dataset = random_split(
+ train_dataset, lengths=[train_len, val_len]
+ )
+
+ # Load test dataset.
+ self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+ self.test_dataset.load_or_generate_data()
+
+ # Set the flag to true to disable ability to load data again.
+ self.data_prepared = True
+
+ def train_dataloader(self) -> DataLoader:
+ """Returns data loader for training set."""
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ """Returns data loader for validation set."""
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ """Returns data loader for test set."""
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=False,
+ pin_memory=True,
+ )
+
+ def _configure_network(self, network_fn: Type[nn.Module]) -> None:
+ """Loads the network."""
+ # If no network arguments are given, load pretrained weights if they exist.
+ # Load network module.
+ network_fn = getattr(networks, network_fn)
+ if self._network_args is None:
+ self.load_weights(network_fn)
+ else:
+ self._network = network_fn(**self._network_args)
+
+ def _configure_criterion(self) -> None:
+ """Loads the criterion."""
+ self._criterion = (
+ self._criterion(**self.criterion_args)
+ if self._criterion is not None
+ else None
+ )
+
+ def _configure_optimizers(self,) -> None:
+ """Loads the optimizers."""
+ if self._optimizer is not None:
+ self._optimizer = self._optimizer(
+ self._network.parameters(), **self.optimizer_args
+ )
+ else:
+ self._optimizer = None
+
+ if self._optimizer and self._lr_scheduler is not None:
+ if "steps_per_epoch" in self.lr_scheduler_args:
+ self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+
+ # Assume lr scheduler should update at each epoch if not specified.
+ if "interval" not in self.lr_scheduler_args:
+ interval = "epoch"
+ else:
+ interval = self.lr_scheduler_args.pop("interval")
+ self._lr_scheduler = {
+ "lr_scheduler": self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ ),
+ "interval": interval,
+ }
+
+ if self.swa_args is not None:
+ self._swa_scheduler = {
+ "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
+ "swa_start": self.swa_args["start"],
+ }
+ self._swa_network = AveragedModel(self._network).to(self.device)
+
+ @property
+ def name(self) -> str:
+ """Returns the name of the model."""
+ return self._name
+
+ @property
+ def input_shape(self) -> Tuple[int, ...]:
+ """The input shape."""
+ return self._input_shape
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the mapper that maps between ints and chars."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between network output and Emnist character."""
+ return self._mapper.mapping if self._mapper is not None else None
+
+ def eval(self) -> None:
+ """Sets the network to evaluation mode."""
+ self._network.eval()
+
+ def train(self) -> None:
+ """Sets the network to train mode."""
+ self._network.train()
+
+ @property
+ def device(self) -> str:
+ """Device where the weights are stored, i.e. cpu or cuda."""
+ return self._device
+
+ @property
+ def metrics(self) -> Optional[Dict]:
+ """Metrics."""
+ return self._metrics
+
+ @property
+ def criterion(self) -> Optional[Callable]:
+ """Criterion."""
+ return self._criterion
+
+ @property
+ def optimizer(self) -> Optional[Callable]:
+ """Optimizer."""
+ return self._optimizer
+
+ @property
+ def lr_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the learning rate scheduler."""
+ return self._lr_scheduler
+
+ @property
+ def swa_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the stochastic weight averaging scheduler."""
+ return self._swa_scheduler
+
+ @property
+ def swa_network(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging network."""
+ return self._swa_network
+
+ @property
+ def network(self) -> Type[nn.Module]:
+ """Neural network."""
+ # Returns the SWA network if available.
+ return self._network
+
+ @property
+ def weights_filename(self) -> str:
+ """Filepath to the network weights."""
+ WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
+
+ def use_swa_model(self) -> None:
+ """Set to use predictions from SWA model."""
+ if self.swa_network is not None:
+ self._use_swa_model = True
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass with the network."""
+ if self._use_swa_model:
+ return self.swa_network(x)
+ else:
+ return self.network(x)
+
+ def summary(
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 3,
+ device: Optional[str] = None,
+ ) -> None:
+ """Prints a summary of the network architecture."""
+ device = self.device if device is None else device
+
+ if input_shape is not None:
+ summary(self.network, input_shape, depth=depth, device=device)
+ elif self._input_shape is not None:
+ input_shape = tuple(self._input_shape)
+ summary(self.network, input_shape, depth=depth, device=device)
+ else:
+ logger.warning("Could not print summary as input shape is not set.")
+
+ def to_device(self) -> None:
+ """Places the network on the device (GPU)."""
+ self._network.to(self._device)
+
+ def _get_state_dict(self) -> Dict:
+ """Get the state dict of the model."""
+ state = {"model_state": self._network.state_dict()}
+
+ if self._optimizer is not None:
+ state["optimizer_state"] = self._optimizer.state_dict()
+
+ if self._lr_scheduler is not None:
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
+
+ if self._swa_network is not None:
+ state["swa_network"] = self._swa_network.state_dict()
+
+ return state
+
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
+ """Load a previously saved checkpoint.
+
+ Args:
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
+
+ """
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
+ logger.debug("Loading checkpoint...")
+ if not checkpoint_path.exists():
+ logger.debug("File does not exist {str(checkpoint_path)}")
+
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
+ self._network.load_state_dict(checkpoint["model_state"])
+
+ if self._optimizer is not None:
+ self._optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+ if self._lr_scheduler is not None:
+ # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
+ # with OneCycleLR.
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
+
+ if self._swa_network is not None:
+ self._swa_network.load_state_dict(checkpoint["swa_network"])
+
+ def save_checkpoint(
+ self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
+ """Saves a checkpoint of the model.
+
+ Args:
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
+ is_best (bool): If it is the currently best model.
+ epoch (int): The epoch of the checkpoint.
+ val_metric (str): Validation metric.
+
+ """
+ state = self._get_state_dict()
+ state["is_best"] = is_best
+ state["epoch"] = epoch
+ state["network_args"] = self._network_args
+
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
+
+ logger.debug("Saving checkpoint...")
+ filepath = str(checkpoint_path / "last.pt")
+ torch.save(state, filepath)
+
+ if is_best:
+ logger.debug(
+ f"Found a new best {val_metric}. Saving best checkpoint and weights."
+ )
+ shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
+
+ def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
+ """Load the network weights."""
+ logger.debug("Loading network with pretrained weights.")
+ filename = glob(self.weights_filename)[0]
+ if not filename:
+ raise FileNotFoundError(
+ f"Could not find any pretrained weights at {self.weights_filename}"
+ )
+ # Loading state directory.
+ state_dict = torch.load(filename, map_location=torch.device(self._device))
+ self._network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ if network_fn is not None:
+ self._network = network_fn(**self._network_args)
+ self._network.load_state_dict(weights)
+
+ if "swa_network" in state_dict:
+ self._swa_network = AveragedModel(self._network).to(self.device)
+ self._swa_network.load_state_dict(state_dict["swa_network"])
+
+ def save_weights(self, path: Path) -> None:
+ """Save the network weights."""
+ logger.debug("Saving the best network weights.")
+ shutil.copyfile(str(path / "best.pt"), self.weights_filename)
diff --git a/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py
new file mode 100644
index 0000000..f9944f3
--- /dev/null
+++ b/text_recognizer/models/character_model.py
@@ -0,0 +1,88 @@
+"""Defines the CharacterModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class CharacterModel(Model):
+ """Model for predicting characters from images."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
+
+ @torch.no_grad()
+ def predict_on_image(
+ self, image: Union[np.ndarray, torch.Tensor]
+ ) -> Tuple[str, float]:
+ """Character prediction on an image.
+
+ Args:
+ image (Union[np.ndarray, torch.Tensor]): An image containing a character.
+
+ Returns:
+ Tuple[str, float]: The predicted character and the confidence in the prediction.
+
+ """
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ logits = self.forward(image)
+
+ prediction = self.softmax(logits.squeeze(0))
+
+ index = int(torch.argmax(prediction, dim=0))
+ confidence_of_prediction = prediction[index]
+ predicted_character = self.mapper(index)
+
+ return predicted_character, confidence_of_prediction
diff --git a/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py
new file mode 100644
index 0000000..1e01a83
--- /dev/null
+++ b/text_recognizer/models/crnn_model.py
@@ -0,0 +1,119 @@
+"""Defines the CRNNModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CRNNModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+
+ # Input lengths on the form [T, B]
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+
+ # Configure target tensors for ctc loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 79]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+ )
+
+ return self._criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = self.forward(image)
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=79,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+ return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py
new file mode 100644
index 0000000..25925f2
--- /dev/null
+++ b/text_recognizer/models/ctc_transformer_model.py
@@ -0,0 +1,120 @@
+"""Defines the CTC Transformer Model class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CTCTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.lower = dataset_args["args"]["lower"]
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
+
+ self.tensor_transform = ToTensor()
+
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+ # Input lengths on the form [T, B]
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+
+ # Configure target tensors for ctc loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 53]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+ )
+
+ return self._criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = self.forward(image)
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=53,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+ return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+ """Model for segmenting lines in an image."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype is np.uint8:
+ # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype is torch.uint8 or image.dtype is torch.int64:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ if not torch.is_tensor(image):
+ image = Tensor(image)
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ logits = self.forward(image)
+
+ segmentation_mask = torch.argmax(logits, dim=1)
+
+ return segmentation_mask
diff --git a/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py
new file mode 100644
index 0000000..3f63053
--- /dev/null
+++ b/text_recognizer/models/transformer_model.py
@@ -0,0 +1,124 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+
+from text_recognizer.datasets import EmnistMapper
+import text_recognizer.datasets.transforms as transforms
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class TransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ self.lower = dataset_args["args"]["lower"]
+ self.max_len = 100
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ lower=self.lower,
+ )
+ self.tensor_transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
+ )
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+ src = self.network.extract_image_features(image)
+
+ # Added for vqvae transformer.
+ if isinstance(src, Tuple):
+ src = src[0]
+
+ memory = self.network.encoder(src)
+
+ confidence_of_predictions = []
+ trg_indices = [self.mapper(self.init_token)]
+
+ for _ in range(self.max_len - 1):
+ trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
+ trg = self.network.target_embedding(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
+
+ # Convert logits to probabilities.
+ probs = self.softmax(logits)
+
+ pred_token = probs.argmax(2)[:, -1].item()
+ confidence = probs.max(2).values[:, -1].item()
+
+ trg_indices.append(pred_token)
+ confidence_of_predictions.append(confidence)
+
+ if pred_token == self.mapper(self.eos_token):
+ break
+
+ confidence = np.min(confidence_of_predictions)
+ predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
+
+ return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py
new file mode 100644
index 0000000..70f6f1f
--- /dev/null
+++ b/text_recognizer/models/vqvae_model.py
@@ -0,0 +1,80 @@
+"""Defines the VQVAEModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class VQVAEModel(Model):
+ """Model for reconstructing images from codebook."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ """Reconstruction of image.
+
+ Args:
+ image (Union[np.ndarray, torch.Tensor]): An image containing a character.
+
+ Returns:
+ Tuple[str, float]: The predicted character and the confidence in the prediction.
+
+ """
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ image_reconstructed, _ = self.forward(image)
+
+ return image_reconstructed