diff options
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/__init__.py | 5 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 331 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 20 | ||||
-rw-r--r-- | src/text_recognizer/models/line_ctc_model.py | 105 | ||||
-rw-r--r-- | src/text_recognizer/models/metrics.py | 80 |
5 files changed, 417 insertions, 124 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index ff10a07..a3cfc15 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1,6 +1,7 @@ """Model modules.""" from .base import Model from .character_model import CharacterModel -from .metrics import accuracy +from .line_ctc_model import LineCTCModel +from .metrics import accuracy, cer, wer -__all__ = ["Model", "CharacterModel", "accuracy"] +__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 3a84a11..153e19a 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from glob import glob +import importlib from pathlib import Path import re import shutil @@ -10,9 +11,12 @@ from typing import Callable, Dict, Optional, Tuple, Type from loguru import logger import torch from torch import nn +from torch import Tensor +from torch.optim.swa_utils import AveragedModel, SWALR +from torch.utils.data import DataLoader, Dataset, random_split from torchsummary import summary -from text_recognizer.datasets import EmnistMapper, fetch_data_loaders +from text_recognizer.datasets import EmnistMapper WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -23,8 +27,9 @@ class Model(ABC): def __init__( self, network_fn: Type[nn.Module], + dataset: Type[Dataset], network_args: Optional[Dict] = None, - data_loader_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, criterion_args: Optional[Dict] = None, @@ -32,14 +37,16 @@ class Model(ABC): optimizer_args: Optional[Dict] = None, lr_scheduler: Optional[Callable] = None, lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: """Base class, to be inherited by model for specific type of data. Args: network_fn (Type[nn.Module]): The PyTorch network. + dataset (Type[Dataset]): A dataset class. network_args (Optional[Dict]): Arguments for the network. Defaults to None. - data_loader_args (Optional[Dict]): Arguments for the DataLoader. + dataset_args (Optional[Dict]): Arguments for the dataset. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. Defaults to None. @@ -49,107 +56,181 @@ class Model(ABC): lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to None. + swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to + None. device (Optional[str]): Name of the device to train on. Defaults to None. """ + # Has to be set in subclass. + self._mapper = None - # Configure data loaders and dataset info. - dataset_name, self._data_loaders, self._mapper = self._configure_data_loader( - data_loader_args - ) - self._input_shape = self._mapper.input_shape + # Placeholder. + self._input_shape = None + + self.dataset = dataset + self.dataset_args = dataset_args + + # Placeholders for datasets. + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None - self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" + # Stochastic Weight Averaging placeholders. + self.swa_args = swa_args + self._swa_start = None + self._swa_scheduler = None + self._swa_network = None - if metrics is not None: - self._metrics = metrics + # Experiment directory. + self.model_dir = None + + # Flag for configured model. + self.is_configured = False + self.data_prepared = False + + # Flag for stopping training. + self.stop_training = False + + self._name = ( + f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}" + ) + + self._metrics = metrics if metrics is not None else None # Set the device. - if device is None: - self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - self._device = device + self._device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) # Configure network. - self._network, self._network_args = self._configure_network( - network_fn, network_args - ) + self._network = None + self._network_args = network_args + self._configure_network(network_fn) - # To device. - self._network.to(self._device) + # Place network on device (GPU). + self.to_device() + + # Loss and Optimizer placeholders for before loading. + self._criterion = criterion + self.criterion_args = criterion_args + + self._optimizer = optimizer + self.optimizer_args = optimizer_args + + self._lr_scheduler = lr_scheduler + self.lr_scheduler_args = lr_scheduler_args + + def configure_model(self) -> None: + """Configures criterion and optimizers.""" + if not self.is_configured: + self._configure_criterion() + self._configure_optimizers() + + # Prints a summary of the network in terminal. + self.summary() + + # Set this flag to true to prevent the model from configuring again. + self.is_configured = True + + def prepare_data(self) -> None: + """Prepare data for training.""" + # TODO add downloading. + if not self.data_prepared: + # Load train dataset. + train_dataset = self.dataset(train=True, **self.dataset_args["args"]) + + # Set input shape. + self._input_shape = train_dataset.input_shape + + # Split train dataset into a training and validation partition. + dataset_len = len(train_dataset) + train_len = int( + self.dataset_args["train_args"]["train_fraction"] * dataset_len + ) + val_len = dataset_len - train_len + self.train_dataset, self.val_dataset = random_split( + train_dataset, lengths=[train_len, val_len] + ) + + # Load test dataset. + self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) + + # Set the flag to true to disable ability to load data agian. + self.data_prepared = True - # Configure training objects. - self._criterion = self._configure_criterion(criterion, criterion_args) - self._optimizer, self._lr_scheduler = self._configure_optimizers( - optimizer, optimizer_args, lr_scheduler, lr_scheduler_args + def train_dataloader(self) -> DataLoader: + """Returns data loader for training set.""" + return DataLoader( + self.train_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, ) - # Experiment directory. - self.model_dir = None + def val_dataloader(self) -> DataLoader: + """Returns data loader for validation set.""" + return DataLoader( + self.val_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, + ) - # Flag for stopping training. - self.stop_training = False + def test_dataloader(self) -> DataLoader: + """Returns data loader for test set.""" + return DataLoader( + self.test_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=False, + pin_memory=True, + ) - def _configure_data_loader( - self, data_loader_args: Optional[Dict] - ) -> Tuple[str, Dict, EmnistMapper]: - """Loads data loader, dataset name, and dataset mapper.""" - if data_loader_args is not None: - data_loaders = fetch_data_loaders(**data_loader_args) - dataset = list(data_loaders.values())[0].dataset - dataset_name = dataset.__name__ - mapper = dataset.mapper - else: - self._mapper = EmnistMapper() - dataset_name = "*" - data_loaders = None - return dataset_name, data_loaders, mapper - - def _configure_network( - self, network_fn: Type[nn.Module], network_args: Optional[Dict] - ) -> Tuple[Type[nn.Module], Dict]: + def _configure_network(self, network_fn: Type[nn.Module]) -> None: """Loads the network.""" # If no network arguemnts are given, load pretrained weights if they exist. - if network_args is None: - network, network_args = self.load_weights(network_fn) + if self._network_args is None: + self.load_weights(network_fn) else: - network = network_fn(**network_args) - return network, network_args + self._network = network_fn(**self._network_args) - def _configure_criterion( - self, criterion: Optional[Callable], criterion_args: Optional[Dict] - ) -> Optional[Callable]: + def _configure_criterion(self) -> None: """Loads the criterion.""" - if criterion is not None: - _criterion = criterion(**criterion_args) - else: - _criterion = None - return _criterion + self._criterion = ( + self._criterion(**self.criterion_args) + if self._criterion is not None + else None + ) - def _configure_optimizers( - self, - optimizer: Optional[Callable], - optimizer_args: Optional[Dict], - lr_scheduler: Optional[Callable], - lr_scheduler_args: Optional[Dict], - ) -> Tuple[Optional[Callable], Optional[Callable]]: + def _configure_optimizers(self,) -> None: """Loads the optimizers.""" - if optimizer is not None: - _optimizer = optimizer(self._network.parameters(), **optimizer_args) + if self._optimizer is not None: + self._optimizer = self._optimizer( + self._network.parameters(), **self.optimizer_args + ) else: - _optimizer = None + self._optimizer = None - if _optimizer and lr_scheduler is not None: - if "OneCycleLR" in str(lr_scheduler): - lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) - _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args) + if self._optimizer and self._lr_scheduler is not None: + if "OneCycleLR" in str(self._lr_scheduler): + self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) + self._lr_scheduler = self._lr_scheduler( + self._optimizer, **self.lr_scheduler_args + ) else: - _lr_scheduler = None + self._lr_scheduler = None - return _optimizer, _lr_scheduler + if self.swa_args is not None: + self._swa_start = self.swa_args["start"] + self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"]) + self._swa_network = AveragedModel(self._network).to(self.device) @property - def __name__(self) -> str: + def name(self) -> str: """Returns the name of the model.""" return self._name @@ -159,7 +240,7 @@ class Model(ABC): return self._input_shape @property - def mapper(self) -> Dict: + def mapper(self) -> EmnistMapper: """Returns the mapper that maps between ints and chars.""" return self._mapper @@ -202,13 +283,24 @@ class Model(ABC): return self._lr_scheduler @property - def data_loaders(self) -> Optional[Dict]: - """Dataloaders.""" - return self._data_loaders + def swa_scheduler(self) -> Optional[Callable]: + """Returns the stochastic weight averaging scheduler.""" + return self._swa_scheduler + + @property + def swa_start(self) -> Optional[Callable]: + """Returns the start epoch of stochastic weight averaging.""" + return self._swa_start @property - def network(self) -> nn.Module: + def swa_network(self) -> Optional[Callable]: + """Returns the stochastic weight averaging network.""" + return self._swa_network + + @property + def network(self) -> Type[nn.Module]: """Neural network.""" + # Returns the SWA network if available. return self._network @property @@ -217,15 +309,27 @@ class Model(ABC): WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") - def summary(self) -> None: + def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: + """Compute the loss.""" + return self.criterion(output, targets) + + def summary( + self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5 + ) -> None: """Prints a summary of the network architecture.""" - device = re.sub("[^A-Za-z]+", "", self.device) - if self._input_shape is not None: + + if input_shape is not None: + summary(self._network, input_shape, depth=depth, device=self.device) + elif self._input_shape is not None: input_shape = (1,) + tuple(self._input_shape) - summary(self._network, input_shape, device=device) + summary(self._network, input_shape, depth=depth, device=self.device) else: logger.warning("Could not print summary as input shape is not set.") + def to_device(self) -> None: + """Places the network on the device (GPU).""" + self._network.to(self._device) + def _get_state_dict(self) -> Dict: """Get the state dict of the model.""" state = {"model_state": self._network.state_dict()} @@ -236,69 +340,67 @@ class Model(ABC): if self._lr_scheduler is not None: state["scheduler_state"] = self._lr_scheduler.state_dict() + if self._swa_network is not None: + state["swa_network"] = self._swa_network.state_dict() + return state - def load_checkpoint(self, path: Path) -> int: + def load_from_checkpoint(self, checkpoint_path: Path) -> None: """Load a previously saved checkpoint. Args: - path (Path): Path to the experiment with the checkpoint. - - Returns: - epoch (int): The last epoch when the checkpoint was created. + checkpoint_path (Path): Path to the experiment with the checkpoint. """ logger.debug("Loading checkpoint...") - if not path.exists(): - logger.debug("File does not exist {str(path)}") + if not checkpoint_path.exists(): + logger.debug("File does not exist {str(checkpoint_path)}") - checkpoint = torch.load(str(path)) + checkpoint = torch.load(str(checkpoint_path)) self._network.load_state_dict(checkpoint["model_state"]) if self._optimizer is not None: self._optimizer.load_state_dict(checkpoint["optimizer_state"]) - # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs. - # if self._lr_scheduler is not None: - # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) - - epoch = checkpoint["epoch"] + if self._lr_scheduler is not None: + # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs + # with OneCycleLR. + if self._lr_scheduler.__class__.__name__ != "OneCycleLR": + self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) - return epoch + if self._swa_network is not None: + self._swa_network.load_state_dict(checkpoint["swa_network"]) - def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None: + def save_checkpoint( + self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str + ) -> None: """Saves a checkpoint of the model. Args: + checkpoint_path (Path): Path to the experiment with the checkpoint. is_best (bool): If it is the currently best model. epoch (int): The epoch of the checkpoint. val_metric (str): Validation metric. - Raises: - ValueError: If the self.model_dir is not set. - """ state = self._get_state_dict() state["is_best"] = is_best state["epoch"] = epoch state["network_args"] = self._network_args - if self.model_dir is None: - raise ValueError("Experiment directory is not set.") - - self.model_dir.mkdir(parents=True, exist_ok=True) + checkpoint_path.mkdir(parents=True, exist_ok=True) logger.debug("Saving checkpoint...") - filepath = str(self.model_dir / "last.pt") + filepath = str(checkpoint_path / "last.pt") torch.save(state, filepath) if is_best: logger.debug( f"Found a new best {val_metric}. Saving best checkpoint and weights." ) - shutil.copyfile(filepath, str(self.model_dir / "best.pt")) + shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) - def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]: + def load_weights(self, network_fn: Type[nn.Module]) -> None: """Load the network weights.""" logger.debug("Loading network with pretrained weights.") filename = glob(self.weights_filename)[0] @@ -308,13 +410,16 @@ class Model(ABC): ) # Loading state directory. state_dict = torch.load(filename, map_location=torch.device(self._device)) - network_args = state_dict["network_args"] + self._network_args = state_dict["network_args"] weights = state_dict["model_state"] # Initializes the network with trained weights. - network = network_fn(**self._network_args) - network.load_state_dict(weights) - return network, network_args + self._network = network_fn(**self._network_args) + self._network.load_state_dict(weights) + + if "swa_network" in state_dict: + self._swa_network = AveragedModel(self._network).to(self.device) + self._swa_network.load_state_dict(state_dict["swa_network"]) def save_weights(self, path: Path) -> None: """Save the network weights.""" diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0fd7afd..64ba693 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch from torch import nn +from torch.utils.data import Dataset from torchvision.transforms import ToTensor +from text_recognizer.datasets import EmnistMapper from text_recognizer.models.base import Model @@ -15,8 +17,9 @@ class CharacterModel(Model): def __init__( self, network_fn: Type[nn.Module], + dataset: Type[Dataset], network_args: Optional[Dict] = None, - data_loader_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, criterion_args: Optional[Dict] = None, @@ -24,14 +27,16 @@ class CharacterModel(Model): optimizer_args: Optional[Dict] = None, lr_scheduler: Optional[Callable] = None, lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: """Initializes the CharacterModel.""" super().__init__( network_fn, + dataset, network_args, - data_loader_args, + dataset_args, metrics, criterion, criterion_args, @@ -39,8 +44,11 @@ class CharacterModel(Model): optimizer_args, lr_scheduler, lr_scheduler_args, + swa_args, device, ) + if self._mapper is None: + self._mapper = EmnistMapper() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) @@ -67,9 +75,13 @@ class CharacterModel(Model): # Put the image tensor on the device the model weights are on. image = image.to(self.device) - logits = self.network(image) + logits = ( + self.swa_network(image) + if self.swa_network is not None + else self.network(image) + ) - prediction = self.softmax(logits.data.squeeze()) + prediction = self.softmax(logits.squeeze(0)) index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py new file mode 100644 index 0000000..97308a7 --- /dev/null +++ b/src/text_recognizer/models/line_ctc_model.py @@ -0,0 +1,105 @@ +"""Defines the LineCTCModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class LineCTCModel(Model): + """Model for predicting a sequence of characters from an image of a text line.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + if self._mapper is None: + self._mapper = EmnistMapper() + self.tensor_transform = ToTensor() + + def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss. + + Args: + output (Tensor): Model predictions. + targets (Tensor): Correct output sequence. + + Returns: + Tensor: The CTC loss. + + """ + input_lengths = torch.full( + size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, + ) + target_lengths = torch.full( + size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + return self.criterion(output, targets, input_lengths, target_lengths) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + log_probs = ( + self.swa_network(image) + if self.swa_network is not None + else self.network(image) + ) + + raw_pred, _ = greedy_decoder( + predictions=log_probs, + character_mapper=self.mapper, + blank_label=79, + collapse_repeated=True, + ) + + log_probs, _ = log_probs.max(dim=2) + + predicted_characters = "".join(raw_pred[0]) + confidence_of_prediction = torch.exp(log_probs.sum()).item() + + return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index ac8d68e..6a26216 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -1,19 +1,89 @@ """Utility functions for models.""" - +import Levenshtein as Lev import torch +from torch import Tensor + +from text_recognizer.networks import greedy_decoder -def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float: +def accuracy(outputs: Tensor, labels: Tensor) -> float: """Computes the accuracy. Args: - outputs (torch.Tensor): The output from the network. - labels (torch.Tensor): Ground truth labels. + outputs (Tensor): The output from the network. + labels (Tensor): Ground truth labels. Returns: float: The accuracy for the batch. """ _, predicted = torch.max(outputs.data, dim=1) - acc = (predicted == labels).sum().item() / labels.shape[0] + acc = (predicted == labels).sum().float() / labels.shape[0] + acc = acc.item() return acc + + +def cer(outputs: Tensor, targets: Tensor) -> float: + """Computes the character error rate. + + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + + Returns: + float: The cer for the batch. + + """ + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + prediction, target = ( + prediction.replace(" ", ""), + target.replace(" ", ""), + ) + lev_dist += Lev.distance(prediction, target) + return lev_dist / len(decoded_predictions) + + +def wer(outputs: Tensor, targets: Tensor) -> float: + """Computes the Word error rate. + + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + + Returns: + float: The wer for the batch. + + """ + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + + b = set(prediction.split() + target.split()) + word2char = dict(zip(b, range(len(b)))) + + w1 = [chr(word2char[w]) for w in prediction.split()] + w2 = [chr(word2char[w]) for w in target.split()] + + lev_dist += Lev.distance("".join(w1), "".join(w2)) + + return lev_dist / len(decoded_predictions) |