diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/models | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/__init__.py | 18 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 455 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 88 | ||||
-rw-r--r-- | src/text_recognizer/models/crnn_model.py | 119 | ||||
-rw-r--r-- | src/text_recognizer/models/ctc_transformer_model.py | 120 | ||||
-rw-r--r-- | src/text_recognizer/models/segmentation_model.py | 75 | ||||
-rw-r--r-- | src/text_recognizer/models/transformer_model.py | 124 | ||||
-rw-r--r-- | src/text_recognizer/models/vqvae_model.py | 80 |
8 files changed, 0 insertions, 1079 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py deleted file mode 100644 index 7647d7e..0000000 --- a/src/text_recognizer/models/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Model modules.""" -from .base import Model -from .character_model import CharacterModel -from .crnn_model import CRNNModel -from .ctc_transformer_model import CTCTransformerModel -from .segmentation_model import SegmentationModel -from .transformer_model import TransformerModel -from .vqvae_model import VQVAEModel - -__all__ = [ - "CharacterModel", - "CRNNModel", - "CTCTransformerModel", - "Model", - "SegmentationModel", - "TransformerModel", - "VQVAEModel", -] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py deleted file mode 100644 index 70f4cdb..0000000 --- a/src/text_recognizer/models/base.py +++ /dev/null @@ -1,455 +0,0 @@ -"""Abstract Model class for PyTorch neural networks.""" - -from abc import ABC, abstractmethod -from glob import glob -import importlib -from pathlib import Path -import re -import shutil -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -from loguru import logger -import torch -from torch import nn -from torch import Tensor -from torch.optim.swa_utils import AveragedModel, SWALR -from torch.utils.data import DataLoader, Dataset, random_split -from torchsummary import summary - -from text_recognizer import datasets -from text_recognizer import networks -from text_recognizer.datasets import EmnistMapper - -WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" - - -class Model(ABC): - """Abstract Model class with composition of different parts defining a PyTorch neural network.""" - - def __init__( - self, - network_fn: str, - dataset: str, - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Base class, to be inherited by model for specific type of data. - - Args: - network_fn (str): The name of network. - dataset (str): The name dataset class. - network_args (Optional[Dict]): Arguments for the network. Defaults to None. - dataset_args (Optional[Dict]): Arguments for the dataset. - metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. - criterion (Optional[Callable]): The criterion to evaluate the performance of the network. - Defaults to None. - criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. - optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. - optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None. - lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. - lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to - None. - swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to - None. - device (Optional[str]): Name of the device to train on. Defaults to None. - - """ - self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" - # Has to be set in subclass. - self._mapper = None - - # Placeholder. - self._input_shape = None - - self.dataset_name = dataset - self.dataset = None - self.dataset_args = dataset_args - - # Placeholders for datasets. - self.train_dataset = None - self.val_dataset = None - self.test_dataset = None - - # Stochastic Weight Averaging placeholders. - self.swa_args = swa_args - self._swa_scheduler = None - self._swa_network = None - self._use_swa_model = False - - # Experiment directory. - self.model_dir = None - - # Flag for configured model. - self.is_configured = False - self.data_prepared = False - - # Flag for stopping training. - self.stop_training = False - - self._metrics = metrics if metrics is not None else None - - # Set the device. - self._device = ( - torch.device("cuda" if torch.cuda.is_available() else "cpu") - if device is None - else device - ) - - # Configure network. - self._network = None - self._network_args = network_args - self._configure_network(network_fn) - - # Place network on device (GPU). - self.to_device() - - # Loss and Optimizer placeholders for before loading. - self._criterion = criterion - self.criterion_args = criterion_args - - self._optimizer = optimizer - self.optimizer_args = optimizer_args - - self._lr_scheduler = lr_scheduler - self.lr_scheduler_args = lr_scheduler_args - - def configure_model(self) -> None: - """Configures criterion and optimizers.""" - if not self.is_configured: - self._configure_criterion() - self._configure_optimizers() - - # Set this flag to true to prevent the model from configuring again. - self.is_configured = True - - def prepare_data(self) -> None: - """Prepare data for training.""" - # TODO add downloading. - if not self.data_prepared: - # Load dataset module. - self.dataset = getattr(datasets, self.dataset_name) - - # Load train dataset. - train_dataset = self.dataset(train=True, **self.dataset_args["args"]) - train_dataset.load_or_generate_data() - - # Set input shape. - self._input_shape = train_dataset.input_shape - - # Split train dataset into a training and validation partition. - dataset_len = len(train_dataset) - train_len = int( - self.dataset_args["train_args"]["train_fraction"] * dataset_len - ) - val_len = dataset_len - train_len - self.train_dataset, self.val_dataset = random_split( - train_dataset, lengths=[train_len, val_len] - ) - - # Load test dataset. - self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) - self.test_dataset.load_or_generate_data() - - # Set the flag to true to disable ability to load data again. - self.data_prepared = True - - def train_dataloader(self) -> DataLoader: - """Returns data loader for training set.""" - return DataLoader( - self.train_dataset, - batch_size=self.dataset_args["train_args"]["batch_size"], - num_workers=self.dataset_args["train_args"]["num_workers"], - shuffle=True, - pin_memory=True, - ) - - def val_dataloader(self) -> DataLoader: - """Returns data loader for validation set.""" - return DataLoader( - self.val_dataset, - batch_size=self.dataset_args["train_args"]["batch_size"], - num_workers=self.dataset_args["train_args"]["num_workers"], - shuffle=True, - pin_memory=True, - ) - - def test_dataloader(self) -> DataLoader: - """Returns data loader for test set.""" - return DataLoader( - self.test_dataset, - batch_size=self.dataset_args["train_args"]["batch_size"], - num_workers=self.dataset_args["train_args"]["num_workers"], - shuffle=False, - pin_memory=True, - ) - - def _configure_network(self, network_fn: Type[nn.Module]) -> None: - """Loads the network.""" - # If no network arguments are given, load pretrained weights if they exist. - # Load network module. - network_fn = getattr(networks, network_fn) - if self._network_args is None: - self.load_weights(network_fn) - else: - self._network = network_fn(**self._network_args) - - def _configure_criterion(self) -> None: - """Loads the criterion.""" - self._criterion = ( - self._criterion(**self.criterion_args) - if self._criterion is not None - else None - ) - - def _configure_optimizers(self,) -> None: - """Loads the optimizers.""" - if self._optimizer is not None: - self._optimizer = self._optimizer( - self._network.parameters(), **self.optimizer_args - ) - else: - self._optimizer = None - - if self._optimizer and self._lr_scheduler is not None: - if "steps_per_epoch" in self.lr_scheduler_args: - self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) - - # Assume lr scheduler should update at each epoch if not specified. - if "interval" not in self.lr_scheduler_args: - interval = "epoch" - else: - interval = self.lr_scheduler_args.pop("interval") - self._lr_scheduler = { - "lr_scheduler": self._lr_scheduler( - self._optimizer, **self.lr_scheduler_args - ), - "interval": interval, - } - - if self.swa_args is not None: - self._swa_scheduler = { - "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]), - "swa_start": self.swa_args["start"], - } - self._swa_network = AveragedModel(self._network).to(self.device) - - @property - def name(self) -> str: - """Returns the name of the model.""" - return self._name - - @property - def input_shape(self) -> Tuple[int, ...]: - """The input shape.""" - return self._input_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the mapper that maps between ints and chars.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Returns the mapping between network output and Emnist character.""" - return self._mapper.mapping if self._mapper is not None else None - - def eval(self) -> None: - """Sets the network to evaluation mode.""" - self._network.eval() - - def train(self) -> None: - """Sets the network to train mode.""" - self._network.train() - - @property - def device(self) -> str: - """Device where the weights are stored, i.e. cpu or cuda.""" - return self._device - - @property - def metrics(self) -> Optional[Dict]: - """Metrics.""" - return self._metrics - - @property - def criterion(self) -> Optional[Callable]: - """Criterion.""" - return self._criterion - - @property - def optimizer(self) -> Optional[Callable]: - """Optimizer.""" - return self._optimizer - - @property - def lr_scheduler(self) -> Optional[Dict]: - """Returns a directory with the learning rate scheduler.""" - return self._lr_scheduler - - @property - def swa_scheduler(self) -> Optional[Dict]: - """Returns a directory with the stochastic weight averaging scheduler.""" - return self._swa_scheduler - - @property - def swa_network(self) -> Optional[Callable]: - """Returns the stochastic weight averaging network.""" - return self._swa_network - - @property - def network(self) -> Type[nn.Module]: - """Neural network.""" - # Returns the SWA network if available. - return self._network - - @property - def weights_filename(self) -> str: - """Filepath to the network weights.""" - WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) - return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") - - def use_swa_model(self) -> None: - """Set to use predictions from SWA model.""" - if self.swa_network is not None: - self._use_swa_model = True - - def forward(self, x: Tensor) -> Tensor: - """Feedforward pass with the network.""" - if self._use_swa_model: - return self.swa_network(x) - else: - return self.network(x) - - def summary( - self, - input_shape: Optional[Union[List, Tuple]] = None, - depth: int = 3, - device: Optional[str] = None, - ) -> None: - """Prints a summary of the network architecture.""" - device = self.device if device is None else device - - if input_shape is not None: - summary(self.network, input_shape, depth=depth, device=device) - elif self._input_shape is not None: - input_shape = tuple(self._input_shape) - summary(self.network, input_shape, depth=depth, device=device) - else: - logger.warning("Could not print summary as input shape is not set.") - - def to_device(self) -> None: - """Places the network on the device (GPU).""" - self._network.to(self._device) - - def _get_state_dict(self) -> Dict: - """Get the state dict of the model.""" - state = {"model_state": self._network.state_dict()} - - if self._optimizer is not None: - state["optimizer_state"] = self._optimizer.state_dict() - - if self._lr_scheduler is not None: - state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() - state["scheduler_interval"] = self._lr_scheduler["interval"] - - if self._swa_network is not None: - state["swa_network"] = self._swa_network.state_dict() - - return state - - def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None: - """Load a previously saved checkpoint. - - Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. - - """ - checkpoint_path = Path(checkpoint_path) - self.prepare_data() - self.configure_model() - logger.debug("Loading checkpoint...") - if not checkpoint_path.exists(): - logger.debug("File does not exist {str(checkpoint_path)}") - - checkpoint = torch.load(str(checkpoint_path), map_location=self.device) - self._network.load_state_dict(checkpoint["model_state"]) - - if self._optimizer is not None: - self._optimizer.load_state_dict(checkpoint["optimizer_state"]) - - if self._lr_scheduler is not None: - # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs - # with OneCycleLR. - if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": - self._lr_scheduler["lr_scheduler"].load_state_dict( - checkpoint["scheduler_state"] - ) - self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] - - if self._swa_network is not None: - self._swa_network.load_state_dict(checkpoint["swa_network"]) - - def save_checkpoint( - self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str - ) -> None: - """Saves a checkpoint of the model. - - Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. - is_best (bool): If it is the currently best model. - epoch (int): The epoch of the checkpoint. - val_metric (str): Validation metric. - - """ - state = self._get_state_dict() - state["is_best"] = is_best - state["epoch"] = epoch - state["network_args"] = self._network_args - - checkpoint_path.mkdir(parents=True, exist_ok=True) - - logger.debug("Saving checkpoint...") - filepath = str(checkpoint_path / "last.pt") - torch.save(state, filepath) - - if is_best: - logger.debug( - f"Found a new best {val_metric}. Saving best checkpoint and weights." - ) - shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) - - def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None: - """Load the network weights.""" - logger.debug("Loading network with pretrained weights.") - filename = glob(self.weights_filename)[0] - if not filename: - raise FileNotFoundError( - f"Could not find any pretrained weights at {self.weights_filename}" - ) - # Loading state directory. - state_dict = torch.load(filename, map_location=torch.device(self._device)) - self._network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - # Initializes the network with trained weights. - if network_fn is not None: - self._network = network_fn(**self._network_args) - self._network.load_state_dict(weights) - - if "swa_network" in state_dict: - self._swa_network = AveragedModel(self._network).to(self.device) - self._swa_network.load_state_dict(state_dict["swa_network"]) - - def save_weights(self, path: Path) -> None: - """Save the network weights.""" - logger.debug("Saving the best network weights.") - shutil.copyfile(str(path / "best.pt"), self.weights_filename) diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py deleted file mode 100644 index f9944f3..0000000 --- a/src/text_recognizer/models/character_model.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Defines the CharacterModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class CharacterModel(Model): - """Model for predicting characters from images.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Initializes the CharacterModel.""" - - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=0) - - @torch.no_grad() - def predict_on_image( - self, image: Union[np.ndarray, torch.Tensor] - ) -> Tuple[str, float]: - """Character prediction on an image. - - Args: - image (Union[np.ndarray, torch.Tensor]): An image containing a character. - - Returns: - Tuple[str, float]: The predicted character and the confidence in the prediction. - - """ - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - logits = self.forward(image) - - prediction = self.softmax(logits.squeeze(0)) - - index = int(torch.argmax(prediction, dim=0)) - confidence_of_prediction = prediction[index] - predicted_character = self.mapper(index) - - return predicted_character, confidence_of_prediction diff --git a/src/text_recognizer/models/crnn_model.py b/src/text_recognizer/models/crnn_model.py deleted file mode 100644 index 1e01a83..0000000 --- a/src/text_recognizer/models/crnn_model.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Defines the CRNNModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class CRNNModel(Model): - """Model for predicting a sequence of characters from an image of a text line.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - - def criterion(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss. - - Args: - output (Tensor): Model predictions. - targets (Tensor): Correct output sequence. - - Returns: - Tensor: The CTC loss. - - """ - - # Input lengths on the form [T, B] - input_lengths = torch.full( - size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, - ) - - # Configure target tensors for ctc loss. - targets_ = Tensor([]).to(self.device) - target_lengths = [] - for t in targets: - # Remove padding symbol as it acts as the blank symbol. - t = t[t < 79] - targets_ = torch.cat([targets_, t]) - target_lengths.append(len(t)) - - targets = targets_.type(dtype=torch.long) - target_lengths = ( - torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) - ) - - return self._criterion(output, targets, input_lengths, target_lengths) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - log_probs = self.forward(image) - - raw_pred, _ = greedy_decoder( - predictions=log_probs, - character_mapper=self.mapper, - blank_label=79, - collapse_repeated=True, - ) - - log_probs, _ = log_probs.max(dim=2) - - predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() - - return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/src/text_recognizer/models/ctc_transformer_model.py deleted file mode 100644 index 25925f2..0000000 --- a/src/text_recognizer/models/ctc_transformer_model.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Defines the CTC Transformer Model class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class CTCTransformerModel(Model): - """Model for predicting a sequence of characters from an image of a text line.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - self.lower = dataset_args["args"]["lower"] - - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,) - - self.tensor_transform = ToTensor() - - def criterion(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss. - - Args: - output (Tensor): Model predictions. - targets (Tensor): Correct output sequence. - - Returns: - Tensor: The CTC loss. - - """ - # Input lengths on the form [T, B] - input_lengths = torch.full( - size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, - ) - - # Configure target tensors for ctc loss. - targets_ = Tensor([]).to(self.device) - target_lengths = [] - for t in targets: - # Remove padding symbol as it acts as the blank symbol. - t = t[t < 53] - targets_ = torch.cat([targets_, t]) - target_lengths.append(len(t)) - - targets = targets_.type(dtype=torch.long) - target_lengths = ( - torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) - ) - - return self._criterion(output, targets, input_lengths, target_lengths) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - log_probs = self.forward(image) - - raw_pred, _ = greedy_decoder( - predictions=log_probs, - character_mapper=self.mapper, - blank_label=53, - collapse_repeated=True, - ) - - log_probs, _ = log_probs.max(dim=2) - - predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() - - return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/segmentation_model.py b/src/text_recognizer/models/segmentation_model.py deleted file mode 100644 index 613108a..0000000 --- a/src/text_recognizer/models/segmentation_model.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Segmentation model for detecting and segmenting lines.""" -from typing import Callable, Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.models.base import Model - - -class SegmentationModel(Model): - """Model for segmenting lines in an image.""" - - def __init__( - self, - network_fn: str, - dataset: str, - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=2) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor: - """Predict on a single input.""" - self.eval() - - if image.dtype is np.uint8: - # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype is torch.uint8 or image.dtype is torch.int64: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - if not torch.is_tensor(image): - image = Tensor(image) - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - - logits = self.forward(image) - - segmentation_mask = torch.argmax(logits, dim=1) - - return segmentation_mask diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py deleted file mode 100644 index 3f63053..0000000 --- a/src/text_recognizer/models/transformer_model.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Defines the CNN-Transformer class.""" -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset - -from text_recognizer.datasets import EmnistMapper -import text_recognizer.datasets.transforms as transforms -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class TransformerModel(Model): - """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" - - def __init__( - self, - network_fn: str, - dataset: str, - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.init_token = dataset_args["args"]["init_token"] - self.pad_token = dataset_args["args"]["pad_token"] - self.eos_token = dataset_args["args"]["eos_token"] - self.lower = dataset_args["args"]["lower"] - self.max_len = 100 - - if self._mapper is None: - self._mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=self.lower, - ) - self.tensor_transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] - ) - self.softmax = nn.Softmax(dim=2) - - @torch.no_grad() - def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: - src = self.network.extract_image_features(image) - - # Added for vqvae transformer. - if isinstance(src, Tuple): - src = src[0] - - memory = self.network.encoder(src) - - confidence_of_predictions = [] - trg_indices = [self.mapper(self.init_token)] - - for _ in range(self.max_len - 1): - trg = torch.tensor(trg_indices, device=self.device)[None, :].long() - trg = self.network.target_embedding(trg) - logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) - - # Convert logits to probabilities. - probs = self.softmax(logits) - - pred_token = probs.argmax(2)[:, -1].item() - confidence = probs.max(2).values[:, -1].item() - - trg_indices.append(pred_token) - confidence_of_predictions.append(confidence) - - if pred_token == self.mapper(self.eos_token): - break - - confidence = np.min(confidence_of_predictions) - predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) - - return predicted_characters, confidence - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - - (predicted_characters, confidence_of_prediction,) = self._generate_sentence( - image - ) - - return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py deleted file mode 100644 index 70f6f1f..0000000 --- a/src/text_recognizer/models/vqvae_model.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Defines the VQVAEModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class VQVAEModel(Model): - """Model for reconstructing images from codebook.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Initializes the CharacterModel.""" - - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=0) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: - """Reconstruction of image. - - Args: - image (Union[np.ndarray, torch.Tensor]): An image containing a character. - - Returns: - Tuple[str, float]: The predicted character and the confidence in the prediction. - - """ - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - image_reconstructed, _ = self.forward(image) - - return image_reconstructed |