summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/models
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/__init__.py18
-rw-r--r--src/text_recognizer/models/base.py455
-rw-r--r--src/text_recognizer/models/character_model.py88
-rw-r--r--src/text_recognizer/models/crnn_model.py119
-rw-r--r--src/text_recognizer/models/ctc_transformer_model.py120
-rw-r--r--src/text_recognizer/models/segmentation_model.py75
-rw-r--r--src/text_recognizer/models/transformer_model.py124
-rw-r--r--src/text_recognizer/models/vqvae_model.py80
8 files changed, 0 insertions, 1079 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
deleted file mode 100644
index 7647d7e..0000000
--- a/src/text_recognizer/models/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Model modules."""
-from .base import Model
-from .character_model import CharacterModel
-from .crnn_model import CRNNModel
-from .ctc_transformer_model import CTCTransformerModel
-from .segmentation_model import SegmentationModel
-from .transformer_model import TransformerModel
-from .vqvae_model import VQVAEModel
-
-__all__ = [
- "CharacterModel",
- "CRNNModel",
- "CTCTransformerModel",
- "Model",
- "SegmentationModel",
- "TransformerModel",
- "VQVAEModel",
-]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
deleted file mode 100644
index 70f4cdb..0000000
--- a/src/text_recognizer/models/base.py
+++ /dev/null
@@ -1,455 +0,0 @@
-"""Abstract Model class for PyTorch neural networks."""
-
-from abc import ABC, abstractmethod
-from glob import glob
-import importlib
-from pathlib import Path
-import re
-import shutil
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-from torch.optim.swa_utils import AveragedModel, SWALR
-from torch.utils.data import DataLoader, Dataset, random_split
-from torchsummary import summary
-
-from text_recognizer import datasets
-from text_recognizer import networks
-from text_recognizer.datasets import EmnistMapper
-
-WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
-
-
-class Model(ABC):
- """Abstract Model class with composition of different parts defining a PyTorch neural network."""
-
- def __init__(
- self,
- network_fn: str,
- dataset: str,
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- """Base class, to be inherited by model for specific type of data.
-
- Args:
- network_fn (str): The name of network.
- dataset (str): The name dataset class.
- network_args (Optional[Dict]): Arguments for the network. Defaults to None.
- dataset_args (Optional[Dict]): Arguments for the dataset.
- metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
- criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
- Defaults to None.
- criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
- optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
- optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
- lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
- lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
- None.
- swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
- None.
- device (Optional[str]): Name of the device to train on. Defaults to None.
-
- """
- self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
- # Has to be set in subclass.
- self._mapper = None
-
- # Placeholder.
- self._input_shape = None
-
- self.dataset_name = dataset
- self.dataset = None
- self.dataset_args = dataset_args
-
- # Placeholders for datasets.
- self.train_dataset = None
- self.val_dataset = None
- self.test_dataset = None
-
- # Stochastic Weight Averaging placeholders.
- self.swa_args = swa_args
- self._swa_scheduler = None
- self._swa_network = None
- self._use_swa_model = False
-
- # Experiment directory.
- self.model_dir = None
-
- # Flag for configured model.
- self.is_configured = False
- self.data_prepared = False
-
- # Flag for stopping training.
- self.stop_training = False
-
- self._metrics = metrics if metrics is not None else None
-
- # Set the device.
- self._device = (
- torch.device("cuda" if torch.cuda.is_available() else "cpu")
- if device is None
- else device
- )
-
- # Configure network.
- self._network = None
- self._network_args = network_args
- self._configure_network(network_fn)
-
- # Place network on device (GPU).
- self.to_device()
-
- # Loss and Optimizer placeholders for before loading.
- self._criterion = criterion
- self.criterion_args = criterion_args
-
- self._optimizer = optimizer
- self.optimizer_args = optimizer_args
-
- self._lr_scheduler = lr_scheduler
- self.lr_scheduler_args = lr_scheduler_args
-
- def configure_model(self) -> None:
- """Configures criterion and optimizers."""
- if not self.is_configured:
- self._configure_criterion()
- self._configure_optimizers()
-
- # Set this flag to true to prevent the model from configuring again.
- self.is_configured = True
-
- def prepare_data(self) -> None:
- """Prepare data for training."""
- # TODO add downloading.
- if not self.data_prepared:
- # Load dataset module.
- self.dataset = getattr(datasets, self.dataset_name)
-
- # Load train dataset.
- train_dataset = self.dataset(train=True, **self.dataset_args["args"])
- train_dataset.load_or_generate_data()
-
- # Set input shape.
- self._input_shape = train_dataset.input_shape
-
- # Split train dataset into a training and validation partition.
- dataset_len = len(train_dataset)
- train_len = int(
- self.dataset_args["train_args"]["train_fraction"] * dataset_len
- )
- val_len = dataset_len - train_len
- self.train_dataset, self.val_dataset = random_split(
- train_dataset, lengths=[train_len, val_len]
- )
-
- # Load test dataset.
- self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
- self.test_dataset.load_or_generate_data()
-
- # Set the flag to true to disable ability to load data again.
- self.data_prepared = True
-
- def train_dataloader(self) -> DataLoader:
- """Returns data loader for training set."""
- return DataLoader(
- self.train_dataset,
- batch_size=self.dataset_args["train_args"]["batch_size"],
- num_workers=self.dataset_args["train_args"]["num_workers"],
- shuffle=True,
- pin_memory=True,
- )
-
- def val_dataloader(self) -> DataLoader:
- """Returns data loader for validation set."""
- return DataLoader(
- self.val_dataset,
- batch_size=self.dataset_args["train_args"]["batch_size"],
- num_workers=self.dataset_args["train_args"]["num_workers"],
- shuffle=True,
- pin_memory=True,
- )
-
- def test_dataloader(self) -> DataLoader:
- """Returns data loader for test set."""
- return DataLoader(
- self.test_dataset,
- batch_size=self.dataset_args["train_args"]["batch_size"],
- num_workers=self.dataset_args["train_args"]["num_workers"],
- shuffle=False,
- pin_memory=True,
- )
-
- def _configure_network(self, network_fn: Type[nn.Module]) -> None:
- """Loads the network."""
- # If no network arguments are given, load pretrained weights if they exist.
- # Load network module.
- network_fn = getattr(networks, network_fn)
- if self._network_args is None:
- self.load_weights(network_fn)
- else:
- self._network = network_fn(**self._network_args)
-
- def _configure_criterion(self) -> None:
- """Loads the criterion."""
- self._criterion = (
- self._criterion(**self.criterion_args)
- if self._criterion is not None
- else None
- )
-
- def _configure_optimizers(self,) -> None:
- """Loads the optimizers."""
- if self._optimizer is not None:
- self._optimizer = self._optimizer(
- self._network.parameters(), **self.optimizer_args
- )
- else:
- self._optimizer = None
-
- if self._optimizer and self._lr_scheduler is not None:
- if "steps_per_epoch" in self.lr_scheduler_args:
- self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
-
- # Assume lr scheduler should update at each epoch if not specified.
- if "interval" not in self.lr_scheduler_args:
- interval = "epoch"
- else:
- interval = self.lr_scheduler_args.pop("interval")
- self._lr_scheduler = {
- "lr_scheduler": self._lr_scheduler(
- self._optimizer, **self.lr_scheduler_args
- ),
- "interval": interval,
- }
-
- if self.swa_args is not None:
- self._swa_scheduler = {
- "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
- "swa_start": self.swa_args["start"],
- }
- self._swa_network = AveragedModel(self._network).to(self.device)
-
- @property
- def name(self) -> str:
- """Returns the name of the model."""
- return self._name
-
- @property
- def input_shape(self) -> Tuple[int, ...]:
- """The input shape."""
- return self._input_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the mapper that maps between ints and chars."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between network output and Emnist character."""
- return self._mapper.mapping if self._mapper is not None else None
-
- def eval(self) -> None:
- """Sets the network to evaluation mode."""
- self._network.eval()
-
- def train(self) -> None:
- """Sets the network to train mode."""
- self._network.train()
-
- @property
- def device(self) -> str:
- """Device where the weights are stored, i.e. cpu or cuda."""
- return self._device
-
- @property
- def metrics(self) -> Optional[Dict]:
- """Metrics."""
- return self._metrics
-
- @property
- def criterion(self) -> Optional[Callable]:
- """Criterion."""
- return self._criterion
-
- @property
- def optimizer(self) -> Optional[Callable]:
- """Optimizer."""
- return self._optimizer
-
- @property
- def lr_scheduler(self) -> Optional[Dict]:
- """Returns a directory with the learning rate scheduler."""
- return self._lr_scheduler
-
- @property
- def swa_scheduler(self) -> Optional[Dict]:
- """Returns a directory with the stochastic weight averaging scheduler."""
- return self._swa_scheduler
-
- @property
- def swa_network(self) -> Optional[Callable]:
- """Returns the stochastic weight averaging network."""
- return self._swa_network
-
- @property
- def network(self) -> Type[nn.Module]:
- """Neural network."""
- # Returns the SWA network if available.
- return self._network
-
- @property
- def weights_filename(self) -> str:
- """Filepath to the network weights."""
- WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
- return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
-
- def use_swa_model(self) -> None:
- """Set to use predictions from SWA model."""
- if self.swa_network is not None:
- self._use_swa_model = True
-
- def forward(self, x: Tensor) -> Tensor:
- """Feedforward pass with the network."""
- if self._use_swa_model:
- return self.swa_network(x)
- else:
- return self.network(x)
-
- def summary(
- self,
- input_shape: Optional[Union[List, Tuple]] = None,
- depth: int = 3,
- device: Optional[str] = None,
- ) -> None:
- """Prints a summary of the network architecture."""
- device = self.device if device is None else device
-
- if input_shape is not None:
- summary(self.network, input_shape, depth=depth, device=device)
- elif self._input_shape is not None:
- input_shape = tuple(self._input_shape)
- summary(self.network, input_shape, depth=depth, device=device)
- else:
- logger.warning("Could not print summary as input shape is not set.")
-
- def to_device(self) -> None:
- """Places the network on the device (GPU)."""
- self._network.to(self._device)
-
- def _get_state_dict(self) -> Dict:
- """Get the state dict of the model."""
- state = {"model_state": self._network.state_dict()}
-
- if self._optimizer is not None:
- state["optimizer_state"] = self._optimizer.state_dict()
-
- if self._lr_scheduler is not None:
- state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
- state["scheduler_interval"] = self._lr_scheduler["interval"]
-
- if self._swa_network is not None:
- state["swa_network"] = self._swa_network.state_dict()
-
- return state
-
- def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
- """Load a previously saved checkpoint.
-
- Args:
- checkpoint_path (Path): Path to the experiment with the checkpoint.
-
- """
- checkpoint_path = Path(checkpoint_path)
- self.prepare_data()
- self.configure_model()
- logger.debug("Loading checkpoint...")
- if not checkpoint_path.exists():
- logger.debug("File does not exist {str(checkpoint_path)}")
-
- checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
- self._network.load_state_dict(checkpoint["model_state"])
-
- if self._optimizer is not None:
- self._optimizer.load_state_dict(checkpoint["optimizer_state"])
-
- if self._lr_scheduler is not None:
- # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
- # with OneCycleLR.
- if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
- self._lr_scheduler["lr_scheduler"].load_state_dict(
- checkpoint["scheduler_state"]
- )
- self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
-
- if self._swa_network is not None:
- self._swa_network.load_state_dict(checkpoint["swa_network"])
-
- def save_checkpoint(
- self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
- ) -> None:
- """Saves a checkpoint of the model.
-
- Args:
- checkpoint_path (Path): Path to the experiment with the checkpoint.
- is_best (bool): If it is the currently best model.
- epoch (int): The epoch of the checkpoint.
- val_metric (str): Validation metric.
-
- """
- state = self._get_state_dict()
- state["is_best"] = is_best
- state["epoch"] = epoch
- state["network_args"] = self._network_args
-
- checkpoint_path.mkdir(parents=True, exist_ok=True)
-
- logger.debug("Saving checkpoint...")
- filepath = str(checkpoint_path / "last.pt")
- torch.save(state, filepath)
-
- if is_best:
- logger.debug(
- f"Found a new best {val_metric}. Saving best checkpoint and weights."
- )
- shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
-
- def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
- """Load the network weights."""
- logger.debug("Loading network with pretrained weights.")
- filename = glob(self.weights_filename)[0]
- if not filename:
- raise FileNotFoundError(
- f"Could not find any pretrained weights at {self.weights_filename}"
- )
- # Loading state directory.
- state_dict = torch.load(filename, map_location=torch.device(self._device))
- self._network_args = state_dict["network_args"]
- weights = state_dict["model_state"]
-
- # Initializes the network with trained weights.
- if network_fn is not None:
- self._network = network_fn(**self._network_args)
- self._network.load_state_dict(weights)
-
- if "swa_network" in state_dict:
- self._swa_network = AveragedModel(self._network).to(self.device)
- self._swa_network.load_state_dict(state_dict["swa_network"])
-
- def save_weights(self, path: Path) -> None:
- """Save the network weights."""
- logger.debug("Saving the best network weights.")
- shutil.copyfile(str(path / "best.pt"), self.weights_filename)
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
deleted file mode 100644
index f9944f3..0000000
--- a/src/text_recognizer/models/character_model.py
+++ /dev/null
@@ -1,88 +0,0 @@
-"""Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-
-
-class CharacterModel(Model):
- """Model for predicting characters from images."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- """Initializes the CharacterModel."""
-
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.pad_token = dataset_args["args"]["pad_token"]
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token,)
- self.tensor_transform = ToTensor()
- self.softmax = nn.Softmax(dim=0)
-
- @torch.no_grad()
- def predict_on_image(
- self, image: Union[np.ndarray, torch.Tensor]
- ) -> Tuple[str, float]:
- """Character prediction on an image.
-
- Args:
- image (Union[np.ndarray, torch.Tensor]): An image containing a character.
-
- Returns:
- Tuple[str, float]: The predicted character and the confidence in the prediction.
-
- """
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- logits = self.forward(image)
-
- prediction = self.softmax(logits.squeeze(0))
-
- index = int(torch.argmax(prediction, dim=0))
- confidence_of_prediction = prediction[index]
- predicted_character = self.mapper(index)
-
- return predicted_character, confidence_of_prediction
diff --git a/src/text_recognizer/models/crnn_model.py b/src/text_recognizer/models/crnn_model.py
deleted file mode 100644
index 1e01a83..0000000
--- a/src/text_recognizer/models/crnn_model.py
+++ /dev/null
@@ -1,119 +0,0 @@
-"""Defines the CRNNModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-from text_recognizer.networks import greedy_decoder
-
-
-class CRNNModel(Model):
- """Model for predicting a sequence of characters from an image of a text line."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
-
- self.pad_token = dataset_args["args"]["pad_token"]
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token,)
- self.tensor_transform = ToTensor()
-
- def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
- """Computes the CTC loss.
-
- Args:
- output (Tensor): Model predictions.
- targets (Tensor): Correct output sequence.
-
- Returns:
- Tensor: The CTC loss.
-
- """
-
- # Input lengths on the form [T, B]
- input_lengths = torch.full(
- size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
- )
-
- # Configure target tensors for ctc loss.
- targets_ = Tensor([]).to(self.device)
- target_lengths = []
- for t in targets:
- # Remove padding symbol as it acts as the blank symbol.
- t = t[t < 79]
- targets_ = torch.cat([targets_, t])
- target_lengths.append(len(t))
-
- targets = targets_.type(dtype=torch.long)
- target_lengths = (
- torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
- )
-
- return self._criterion(output, targets, input_lengths, target_lengths)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- log_probs = self.forward(image)
-
- raw_pred, _ = greedy_decoder(
- predictions=log_probs,
- character_mapper=self.mapper,
- blank_label=79,
- collapse_repeated=True,
- )
-
- log_probs, _ = log_probs.max(dim=2)
-
- predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
-
- return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/src/text_recognizer/models/ctc_transformer_model.py
deleted file mode 100644
index 25925f2..0000000
--- a/src/text_recognizer/models/ctc_transformer_model.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""Defines the CTC Transformer Model class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-from text_recognizer.networks import greedy_decoder
-
-
-class CTCTransformerModel(Model):
- """Model for predicting a sequence of characters from an image of a text line."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.pad_token = dataset_args["args"]["pad_token"]
- self.lower = dataset_args["args"]["lower"]
-
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
-
- self.tensor_transform = ToTensor()
-
- def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
- """Computes the CTC loss.
-
- Args:
- output (Tensor): Model predictions.
- targets (Tensor): Correct output sequence.
-
- Returns:
- Tensor: The CTC loss.
-
- """
- # Input lengths on the form [T, B]
- input_lengths = torch.full(
- size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
- )
-
- # Configure target tensors for ctc loss.
- targets_ = Tensor([]).to(self.device)
- target_lengths = []
- for t in targets:
- # Remove padding symbol as it acts as the blank symbol.
- t = t[t < 53]
- targets_ = torch.cat([targets_, t])
- target_lengths.append(len(t))
-
- targets = targets_.type(dtype=torch.long)
- target_lengths = (
- torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
- )
-
- return self._criterion(output, targets, input_lengths, target_lengths)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- log_probs = self.forward(image)
-
- raw_pred, _ = greedy_decoder(
- predictions=log_probs,
- character_mapper=self.mapper,
- blank_label=53,
- collapse_repeated=True,
- )
-
- log_probs, _ = log_probs.max(dim=2)
-
- predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
-
- return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/segmentation_model.py b/src/text_recognizer/models/segmentation_model.py
deleted file mode 100644
index 613108a..0000000
--- a/src/text_recognizer/models/segmentation_model.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Segmentation model for detecting and segmenting lines."""
-from typing import Callable, Dict, Optional, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.models.base import Model
-
-
-class SegmentationModel(Model):
- """Model for segmenting lines in an image."""
-
- def __init__(
- self,
- network_fn: str,
- dataset: str,
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.tensor_transform = ToTensor()
- self.softmax = nn.Softmax(dim=2)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype is np.uint8:
- # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype is torch.uint8 or image.dtype is torch.int64:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- if not torch.is_tensor(image):
- image = Tensor(image)
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
-
- logits = self.forward(image)
-
- segmentation_mask = torch.argmax(logits, dim=1)
-
- return segmentation_mask
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
deleted file mode 100644
index 3f63053..0000000
--- a/src/text_recognizer/models/transformer_model.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""Defines the CNN-Transformer class."""
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-
-from text_recognizer.datasets import EmnistMapper
-import text_recognizer.datasets.transforms as transforms
-from text_recognizer.models.base import Model
-from text_recognizer.networks import greedy_decoder
-
-
-class TransformerModel(Model):
- """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
-
- def __init__(
- self,
- network_fn: str,
- dataset: str,
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.init_token = dataset_args["args"]["init_token"]
- self.pad_token = dataset_args["args"]["pad_token"]
- self.eos_token = dataset_args["args"]["eos_token"]
- self.lower = dataset_args["args"]["lower"]
- self.max_len = 100
-
- if self._mapper is None:
- self._mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- lower=self.lower,
- )
- self.tensor_transform = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
- )
- self.softmax = nn.Softmax(dim=2)
-
- @torch.no_grad()
- def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
- src = self.network.extract_image_features(image)
-
- # Added for vqvae transformer.
- if isinstance(src, Tuple):
- src = src[0]
-
- memory = self.network.encoder(src)
-
- confidence_of_predictions = []
- trg_indices = [self.mapper(self.init_token)]
-
- for _ in range(self.max_len - 1):
- trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
- trg = self.network.target_embedding(trg)
- logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
-
- # Convert logits to probabilities.
- probs = self.softmax(logits)
-
- pred_token = probs.argmax(2)[:, -1].item()
- confidence = probs.max(2).values[:, -1].item()
-
- trg_indices.append(pred_token)
- confidence_of_predictions.append(confidence)
-
- if pred_token == self.mapper(self.eos_token):
- break
-
- confidence = np.min(confidence_of_predictions)
- predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
-
- return predicted_characters, confidence
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
-
- (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
- image
- )
-
- return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py
deleted file mode 100644
index 70f6f1f..0000000
--- a/src/text_recognizer/models/vqvae_model.py
+++ /dev/null
@@ -1,80 +0,0 @@
-"""Defines the VQVAEModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-
-
-class VQVAEModel(Model):
- """Model for reconstructing images from codebook."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- """Initializes the CharacterModel."""
-
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.pad_token = dataset_args["args"]["pad_token"]
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token,)
- self.tensor_transform = ToTensor()
- self.softmax = nn.Softmax(dim=0)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
- """Reconstruction of image.
-
- Args:
- image (Union[np.ndarray, torch.Tensor]): An image containing a character.
-
- Returns:
- Tuple[str, float]: The predicted character and the confidence in the prediction.
-
- """
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- image_reconstructed, _ = self.forward(image)
-
- return image_reconstructed