From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 20 Mar 2021 18:09:06 +0100 Subject: Inital commit for refactoring to lightning --- text_recognizer/models/__init__.py | 18 + text_recognizer/models/base.py | 455 ++++++++++++++++++++++++ text_recognizer/models/character_model.py | 88 +++++ text_recognizer/models/crnn_model.py | 119 +++++++ text_recognizer/models/ctc_transformer_model.py | 120 +++++++ text_recognizer/models/segmentation_model.py | 75 ++++ text_recognizer/models/transformer_model.py | 124 +++++++ text_recognizer/models/vqvae_model.py | 80 +++++ 8 files changed, 1079 insertions(+) create mode 100644 text_recognizer/models/__init__.py create mode 100644 text_recognizer/models/base.py create mode 100644 text_recognizer/models/character_model.py create mode 100644 text_recognizer/models/crnn_model.py create mode 100644 text_recognizer/models/ctc_transformer_model.py create mode 100644 text_recognizer/models/segmentation_model.py create mode 100644 text_recognizer/models/transformer_model.py create mode 100644 text_recognizer/models/vqvae_model.py (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py new file mode 100644 index 0000000..7647d7e --- /dev/null +++ b/text_recognizer/models/__init__.py @@ -0,0 +1,18 @@ +"""Model modules.""" +from .base import Model +from .character_model import CharacterModel +from .crnn_model import CRNNModel +from .ctc_transformer_model import CTCTransformerModel +from .segmentation_model import SegmentationModel +from .transformer_model import TransformerModel +from .vqvae_model import VQVAEModel + +__all__ = [ + "CharacterModel", + "CRNNModel", + "CTCTransformerModel", + "Model", + "SegmentationModel", + "TransformerModel", + "VQVAEModel", +] diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py new file mode 100644 index 0000000..70f4cdb --- /dev/null +++ b/text_recognizer/models/base.py @@ -0,0 +1,455 @@ +"""Abstract Model class for PyTorch neural networks.""" + +from abc import ABC, abstractmethod +from glob import glob +import importlib +from pathlib import Path +import re +import shutil +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +from loguru import logger +import torch +from torch import nn +from torch import Tensor +from torch.optim.swa_utils import AveragedModel, SWALR +from torch.utils.data import DataLoader, Dataset, random_split +from torchsummary import summary + +from text_recognizer import datasets +from text_recognizer import networks +from text_recognizer.datasets import EmnistMapper + +WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" + + +class Model(ABC): + """Abstract Model class with composition of different parts defining a PyTorch neural network.""" + + def __init__( + self, + network_fn: str, + dataset: str, + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Base class, to be inherited by model for specific type of data. + + Args: + network_fn (str): The name of network. + dataset (str): The name dataset class. + network_args (Optional[Dict]): Arguments for the network. Defaults to None. + dataset_args (Optional[Dict]): Arguments for the dataset. + metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. + criterion (Optional[Callable]): The criterion to evaluate the performance of the network. + Defaults to None. + criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. + optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. + optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None. + lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. + lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to + None. + swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to + None. + device (Optional[str]): Name of the device to train on. Defaults to None. + + """ + self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" + # Has to be set in subclass. + self._mapper = None + + # Placeholder. + self._input_shape = None + + self.dataset_name = dataset + self.dataset = None + self.dataset_args = dataset_args + + # Placeholders for datasets. + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + # Stochastic Weight Averaging placeholders. + self.swa_args = swa_args + self._swa_scheduler = None + self._swa_network = None + self._use_swa_model = False + + # Experiment directory. + self.model_dir = None + + # Flag for configured model. + self.is_configured = False + self.data_prepared = False + + # Flag for stopping training. + self.stop_training = False + + self._metrics = metrics if metrics is not None else None + + # Set the device. + self._device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + + # Configure network. + self._network = None + self._network_args = network_args + self._configure_network(network_fn) + + # Place network on device (GPU). + self.to_device() + + # Loss and Optimizer placeholders for before loading. + self._criterion = criterion + self.criterion_args = criterion_args + + self._optimizer = optimizer + self.optimizer_args = optimizer_args + + self._lr_scheduler = lr_scheduler + self.lr_scheduler_args = lr_scheduler_args + + def configure_model(self) -> None: + """Configures criterion and optimizers.""" + if not self.is_configured: + self._configure_criterion() + self._configure_optimizers() + + # Set this flag to true to prevent the model from configuring again. + self.is_configured = True + + def prepare_data(self) -> None: + """Prepare data for training.""" + # TODO add downloading. + if not self.data_prepared: + # Load dataset module. + self.dataset = getattr(datasets, self.dataset_name) + + # Load train dataset. + train_dataset = self.dataset(train=True, **self.dataset_args["args"]) + train_dataset.load_or_generate_data() + + # Set input shape. + self._input_shape = train_dataset.input_shape + + # Split train dataset into a training and validation partition. + dataset_len = len(train_dataset) + train_len = int( + self.dataset_args["train_args"]["train_fraction"] * dataset_len + ) + val_len = dataset_len - train_len + self.train_dataset, self.val_dataset = random_split( + train_dataset, lengths=[train_len, val_len] + ) + + # Load test dataset. + self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) + self.test_dataset.load_or_generate_data() + + # Set the flag to true to disable ability to load data again. + self.data_prepared = True + + def train_dataloader(self) -> DataLoader: + """Returns data loader for training set.""" + return DataLoader( + self.train_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, + ) + + def val_dataloader(self) -> DataLoader: + """Returns data loader for validation set.""" + return DataLoader( + self.val_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, + ) + + def test_dataloader(self) -> DataLoader: + """Returns data loader for test set.""" + return DataLoader( + self.test_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=False, + pin_memory=True, + ) + + def _configure_network(self, network_fn: Type[nn.Module]) -> None: + """Loads the network.""" + # If no network arguments are given, load pretrained weights if they exist. + # Load network module. + network_fn = getattr(networks, network_fn) + if self._network_args is None: + self.load_weights(network_fn) + else: + self._network = network_fn(**self._network_args) + + def _configure_criterion(self) -> None: + """Loads the criterion.""" + self._criterion = ( + self._criterion(**self.criterion_args) + if self._criterion is not None + else None + ) + + def _configure_optimizers(self,) -> None: + """Loads the optimizers.""" + if self._optimizer is not None: + self._optimizer = self._optimizer( + self._network.parameters(), **self.optimizer_args + ) + else: + self._optimizer = None + + if self._optimizer and self._lr_scheduler is not None: + if "steps_per_epoch" in self.lr_scheduler_args: + self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) + + # Assume lr scheduler should update at each epoch if not specified. + if "interval" not in self.lr_scheduler_args: + interval = "epoch" + else: + interval = self.lr_scheduler_args.pop("interval") + self._lr_scheduler = { + "lr_scheduler": self._lr_scheduler( + self._optimizer, **self.lr_scheduler_args + ), + "interval": interval, + } + + if self.swa_args is not None: + self._swa_scheduler = { + "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]), + "swa_start": self.swa_args["start"], + } + self._swa_network = AveragedModel(self._network).to(self.device) + + @property + def name(self) -> str: + """Returns the name of the model.""" + return self._name + + @property + def input_shape(self) -> Tuple[int, ...]: + """The input shape.""" + return self._input_shape + + @property + def mapper(self) -> EmnistMapper: + """Returns the mapper that maps between ints and chars.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Returns the mapping between network output and Emnist character.""" + return self._mapper.mapping if self._mapper is not None else None + + def eval(self) -> None: + """Sets the network to evaluation mode.""" + self._network.eval() + + def train(self) -> None: + """Sets the network to train mode.""" + self._network.train() + + @property + def device(self) -> str: + """Device where the weights are stored, i.e. cpu or cuda.""" + return self._device + + @property + def metrics(self) -> Optional[Dict]: + """Metrics.""" + return self._metrics + + @property + def criterion(self) -> Optional[Callable]: + """Criterion.""" + return self._criterion + + @property + def optimizer(self) -> Optional[Callable]: + """Optimizer.""" + return self._optimizer + + @property + def lr_scheduler(self) -> Optional[Dict]: + """Returns a directory with the learning rate scheduler.""" + return self._lr_scheduler + + @property + def swa_scheduler(self) -> Optional[Dict]: + """Returns a directory with the stochastic weight averaging scheduler.""" + return self._swa_scheduler + + @property + def swa_network(self) -> Optional[Callable]: + """Returns the stochastic weight averaging network.""" + return self._swa_network + + @property + def network(self) -> Type[nn.Module]: + """Neural network.""" + # Returns the SWA network if available. + return self._network + + @property + def weights_filename(self) -> str: + """Filepath to the network weights.""" + WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) + return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") + + def use_swa_model(self) -> None: + """Set to use predictions from SWA model.""" + if self.swa_network is not None: + self._use_swa_model = True + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass with the network.""" + if self._use_swa_model: + return self.swa_network(x) + else: + return self.network(x) + + def summary( + self, + input_shape: Optional[Union[List, Tuple]] = None, + depth: int = 3, + device: Optional[str] = None, + ) -> None: + """Prints a summary of the network architecture.""" + device = self.device if device is None else device + + if input_shape is not None: + summary(self.network, input_shape, depth=depth, device=device) + elif self._input_shape is not None: + input_shape = tuple(self._input_shape) + summary(self.network, input_shape, depth=depth, device=device) + else: + logger.warning("Could not print summary as input shape is not set.") + + def to_device(self) -> None: + """Places the network on the device (GPU).""" + self._network.to(self._device) + + def _get_state_dict(self) -> Dict: + """Get the state dict of the model.""" + state = {"model_state": self._network.state_dict()} + + if self._optimizer is not None: + state["optimizer_state"] = self._optimizer.state_dict() + + if self._lr_scheduler is not None: + state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() + state["scheduler_interval"] = self._lr_scheduler["interval"] + + if self._swa_network is not None: + state["swa_network"] = self._swa_network.state_dict() + + return state + + def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None: + """Load a previously saved checkpoint. + + Args: + checkpoint_path (Path): Path to the experiment with the checkpoint. + + """ + checkpoint_path = Path(checkpoint_path) + self.prepare_data() + self.configure_model() + logger.debug("Loading checkpoint...") + if not checkpoint_path.exists(): + logger.debug("File does not exist {str(checkpoint_path)}") + + checkpoint = torch.load(str(checkpoint_path), map_location=self.device) + self._network.load_state_dict(checkpoint["model_state"]) + + if self._optimizer is not None: + self._optimizer.load_state_dict(checkpoint["optimizer_state"]) + + if self._lr_scheduler is not None: + # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs + # with OneCycleLR. + if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": + self._lr_scheduler["lr_scheduler"].load_state_dict( + checkpoint["scheduler_state"] + ) + self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] + + if self._swa_network is not None: + self._swa_network.load_state_dict(checkpoint["swa_network"]) + + def save_checkpoint( + self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str + ) -> None: + """Saves a checkpoint of the model. + + Args: + checkpoint_path (Path): Path to the experiment with the checkpoint. + is_best (bool): If it is the currently best model. + epoch (int): The epoch of the checkpoint. + val_metric (str): Validation metric. + + """ + state = self._get_state_dict() + state["is_best"] = is_best + state["epoch"] = epoch + state["network_args"] = self._network_args + + checkpoint_path.mkdir(parents=True, exist_ok=True) + + logger.debug("Saving checkpoint...") + filepath = str(checkpoint_path / "last.pt") + torch.save(state, filepath) + + if is_best: + logger.debug( + f"Found a new best {val_metric}. Saving best checkpoint and weights." + ) + shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) + + def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None: + """Load the network weights.""" + logger.debug("Loading network with pretrained weights.") + filename = glob(self.weights_filename)[0] + if not filename: + raise FileNotFoundError( + f"Could not find any pretrained weights at {self.weights_filename}" + ) + # Loading state directory. + state_dict = torch.load(filename, map_location=torch.device(self._device)) + self._network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. + if network_fn is not None: + self._network = network_fn(**self._network_args) + self._network.load_state_dict(weights) + + if "swa_network" in state_dict: + self._swa_network = AveragedModel(self._network).to(self.device) + self._swa_network.load_state_dict(state_dict["swa_network"]) + + def save_weights(self, path: Path) -> None: + """Save the network weights.""" + logger.debug("Saving the best network weights.") + shutil.copyfile(str(path / "best.pt"), self.weights_filename) diff --git a/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py new file mode 100644 index 0000000..f9944f3 --- /dev/null +++ b/text_recognizer/models/character_model.py @@ -0,0 +1,88 @@ +"""Defines the CharacterModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class CharacterModel(Model): + """Model for predicting characters from images.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=0) + + @torch.no_grad() + def predict_on_image( + self, image: Union[np.ndarray, torch.Tensor] + ) -> Tuple[str, float]: + """Character prediction on an image. + + Args: + image (Union[np.ndarray, torch.Tensor]): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. + + """ + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + logits = self.forward(image) + + prediction = self.softmax(logits.squeeze(0)) + + index = int(torch.argmax(prediction, dim=0)) + confidence_of_prediction = prediction[index] + predicted_character = self.mapper(index) + + return predicted_character, confidence_of_prediction diff --git a/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py new file mode 100644 index 0000000..1e01a83 --- /dev/null +++ b/text_recognizer/models/crnn_model.py @@ -0,0 +1,119 @@ +"""Defines the CRNNModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class CRNNModel(Model): + """Model for predicting a sequence of characters from an image of a text line.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + + def criterion(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss. + + Args: + output (Tensor): Model predictions. + targets (Tensor): Correct output sequence. + + Returns: + Tensor: The CTC loss. + + """ + + # Input lengths on the form [T, B] + input_lengths = torch.full( + size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, + ) + + # Configure target tensors for ctc loss. + targets_ = Tensor([]).to(self.device) + target_lengths = [] + for t in targets: + # Remove padding symbol as it acts as the blank symbol. + t = t[t < 79] + targets_ = torch.cat([targets_, t]) + target_lengths.append(len(t)) + + targets = targets_.type(dtype=torch.long) + target_lengths = ( + torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) + ) + + return self._criterion(output, targets, input_lengths, target_lengths) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + log_probs = self.forward(image) + + raw_pred, _ = greedy_decoder( + predictions=log_probs, + character_mapper=self.mapper, + blank_label=79, + collapse_repeated=True, + ) + + log_probs, _ = log_probs.max(dim=2) + + predicted_characters = "".join(raw_pred[0]) + confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() + + return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py new file mode 100644 index 0000000..25925f2 --- /dev/null +++ b/text_recognizer/models/ctc_transformer_model.py @@ -0,0 +1,120 @@ +"""Defines the CTC Transformer Model class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class CTCTransformerModel(Model): + """Model for predicting a sequence of characters from an image of a text line.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.pad_token = dataset_args["args"]["pad_token"] + self.lower = dataset_args["args"]["lower"] + + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,) + + self.tensor_transform = ToTensor() + + def criterion(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss. + + Args: + output (Tensor): Model predictions. + targets (Tensor): Correct output sequence. + + Returns: + Tensor: The CTC loss. + + """ + # Input lengths on the form [T, B] + input_lengths = torch.full( + size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, + ) + + # Configure target tensors for ctc loss. + targets_ = Tensor([]).to(self.device) + target_lengths = [] + for t in targets: + # Remove padding symbol as it acts as the blank symbol. + t = t[t < 53] + targets_ = torch.cat([targets_, t]) + target_lengths.append(len(t)) + + targets = targets_.type(dtype=torch.long) + target_lengths = ( + torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) + ) + + return self._criterion(output, targets, input_lengths, target_lengths) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + log_probs = self.forward(image) + + raw_pred, _ = greedy_decoder( + predictions=log_probs, + character_mapper=self.mapper, + blank_label=53, + collapse_repeated=True, + ) + + log_probs, _ = log_probs.max(dim=2) + + predicted_characters = "".join(raw_pred[0]) + confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() + + return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py new file mode 100644 index 0000000..613108a --- /dev/null +++ b/text_recognizer/models/segmentation_model.py @@ -0,0 +1,75 @@ +"""Segmentation model for detecting and segmenting lines.""" +from typing import Callable, Dict, Optional, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.models.base import Model + + +class SegmentationModel(Model): + """Model for segmenting lines in an image.""" + + def __init__( + self, + network_fn: str, + dataset: str, + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor: + """Predict on a single input.""" + self.eval() + + if image.dtype is np.uint8: + # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype is torch.uint8 or image.dtype is torch.int64: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + if not torch.is_tensor(image): + image = Tensor(image) + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + + logits = self.forward(image) + + segmentation_mask = torch.argmax(logits, dim=1) + + return segmentation_mask diff --git a/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py new file mode 100644 index 0000000..3f63053 --- /dev/null +++ b/text_recognizer/models/transformer_model.py @@ -0,0 +1,124 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset + +from text_recognizer.datasets import EmnistMapper +import text_recognizer.datasets.transforms as transforms +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class TransformerModel(Model): + """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" + + def __init__( + self, + network_fn: str, + dataset: str, + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.init_token = dataset_args["args"]["init_token"] + self.pad_token = dataset_args["args"]["pad_token"] + self.eos_token = dataset_args["args"]["eos_token"] + self.lower = dataset_args["args"]["lower"] + self.max_len = 100 + + if self._mapper is None: + self._mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=self.lower, + ) + self.tensor_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] + ) + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: + src = self.network.extract_image_features(image) + + # Added for vqvae transformer. + if isinstance(src, Tuple): + src = src[0] + + memory = self.network.encoder(src) + + confidence_of_predictions = [] + trg_indices = [self.mapper(self.init_token)] + + for _ in range(self.max_len - 1): + trg = torch.tensor(trg_indices, device=self.device)[None, :].long() + trg = self.network.target_embedding(trg) + logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) + + # Convert logits to probabilities. + probs = self.softmax(logits) + + pred_token = probs.argmax(2)[:, -1].item() + confidence = probs.max(2).values[:, -1].item() + + trg_indices.append(pred_token) + confidence_of_predictions.append(confidence) + + if pred_token == self.mapper(self.eos_token): + break + + confidence = np.min(confidence_of_predictions) + predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) + + return predicted_characters, confidence + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) + + return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py new file mode 100644 index 0000000..70f6f1f --- /dev/null +++ b/text_recognizer/models/vqvae_model.py @@ -0,0 +1,80 @@ +"""Defines the VQVAEModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class VQVAEModel(Model): + """Model for reconstructing images from codebook.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=0) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + """Reconstruction of image. + + Args: + image (Union[np.ndarray, torch.Tensor]): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. + + """ + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + image_reconstructed, _ = self.forward(image) + + return image_reconstructed -- cgit v1.2.3-70-g09d2