Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/base.py            | 84
-rw-r--r-- | src/text_recognizer/models/character_model.py | 32
2 files changed, 75 insertions, 41 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index b78eacb..84a86ca 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -22,7 +22,7 @@ class Model(ABC):
     def __init__(
         self,
         network_fn: Type[nn.Module],
-        network_args: Dict,
+        network_args: Optional[Dict] = None,
         data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
@@ -38,7 +38,7 @@ class Model(ABC):

         Args:
             network_fn (Type[nn.Module]): The PyTorch network.
-            network_args (Dict): Arguments for the network.
+            network_args (Optional[Dict]): Arguments for the network. Defaults to None.
             data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
             data_loader_args (Optional[Dict]): Arguments for the DataLoader.
             metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
@@ -58,18 +58,14 @@ class Model(ABC):
         if data_loader_args is not None:
             self._data_loaders = data_loader(**data_loader_args)
             dataset_name = self._data_loaders.__name__
+            self._mapping = self._data_loaders.mapping
         else:
+            self._mapping = None
             dataset_name = "*"
             self._data_loaders = None

         self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"

-        # Extract the input shape for the torchsummary.
-        if isinstance(network_args["input_size"], int):
-            self._input_shape = (1,) + tuple([network_args["input_size"]])
-        else:
-            self._input_shape = (1,) + tuple(network_args["input_size"])
-
         if metrics is not None:
             self._metrics = metrics

@@ -80,8 +76,13 @@ class Model(ABC):
         self._device = device

         # Load network.
-        self.network_args = network_args
-        self._network = network_fn(**self.network_args)
+        self._network = None
+        self._network_args = network_args
+        # If no network arguments are given, load pretrained weights if they exist.
+        if self._network_args is None:
+            self.load_weights(network_fn)
+        else:
+            self._network = network_fn(**self._network_args)

         # To device.
         self._network.to(self._device)
@@ -104,8 +105,17 @@
             lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
             self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)

-        # Class mapping.
-        self._mapping = None
+        # Extract the input shape for the torchsummary.
+        if isinstance(self._network_args["input_size"], int):
+            self._input_shape = (1,) + tuple([self._network_args["input_size"]])
+        else:
+            self._input_shape = (1,) + tuple(self._network_args["input_size"])
+
+        # Experiment directory.
+        self.model_dir = None
+
+        # Flag for stopping training.
+        self.stop_training = False

     @property
     def __name__(self) -> str:
@@ -179,8 +189,13 @@ class Model(ABC):
     def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
         state = {"model_state": self._network.state_dict()}
+
         if self._optimizer is not None:
             state["optimizer_state"] = self._optimizer.state_dict()
+
+        if self._lr_scheduler is not None:
+            state["scheduler_state"] = self._lr_scheduler.state_dict()
+
         return state

     def load_checkpoint(self, path: Path) -> int:
@@ -203,54 +218,63 @@
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])

+        if self._lr_scheduler is not None:
+            self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+
         epoch = checkpoint["epoch"]

         return epoch

-    def save_checkpoint(
-        self, path: Path, is_best: bool, epoch: int, val_metric: str
-    ) -> None:
+    def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
         """Saves a checkpoint of the model.

         Args:
-            path (Path): Path to the experiment folder.
             is_best (bool): If it is the currently best model.
             epoch (int): The epoch of the checkpoint.
             val_metric (str): Validation metric.

+        Raises:
+            ValueError: If self.model_dir is not set.
+
         """
         state = self._get_state_dict()
         state["is_best"] = is_best
         state["epoch"] = epoch
-        state["network_args"] = self.network_args
+        state["network_args"] = self._network_args

-        path.mkdir(parents=True, exist_ok=True)
+        if self.model_dir is None:
+            raise ValueError("Experiment directory is not set.")
+
+        self.model_dir.mkdir(parents=True, exist_ok=True)

         logger.debug("Saving checkpoint...")
-        filepath = str(path / "last.pt")
+        filepath = str(self.model_dir / "last.pt")
         torch.save(state, filepath)

         if is_best:
             logger.debug(
                 f"Found a new best {val_metric}. Saving best checkpoint and weights."
             )
-            shutil.copyfile(filepath, str(path / "best.pt"))
+            shutil.copyfile(filepath, str(self.model_dir / "best.pt"))

-    def load_weights(self) -> None:
+    def load_weights(self, network_fn: Type[nn.Module]) -> None:
         """Load the network weights."""
-        logger.debug("Loading network weights.")
+        logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
-        weights = torch.load(filename, map_location=torch.device(self._device))[
-            "model_state"
-        ]
+        if not filename:
+            raise FileNotFoundError(
+                f"Could not find any pretrained weights at {self.weights_filename}"
+            )
+        # Loading the state dict.
+        state_dict = torch.load(filename, map_location=torch.device(self._device))
+        self._network_args = state_dict["network_args"]
+        weights = state_dict["model_state"]
+
+        # Initializes the network with trained weights.
+        self._network = network_fn(**self._network_args)
         self._network.load_state_dict(weights)

     def save_weights(self, path: Path) -> None:
         """Save the network weights."""
         logger.debug("Saving the best network weights.")
         shutil.copyfile(str(path / "best.pt"), self.weights_filename)
-
-    @abstractmethod
-    def load_mapping(self) -> None:
-        """Loads class mapping from network output to character."""
-        ...
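Note: a minimal usage sketch of the reworked constructor and checkpointing, using the CharacterModel subclass changed below; the MLP import path, network arguments, experiment directory, and metric name are illustrative assumptions, not part of this commit.

    from pathlib import Path

    from text_recognizer.models.character_model import CharacterModel
    from text_recognizer.networks.mlp import MLP  # assumed import path

    # Passing network_args builds a fresh, untrained network.
    model = CharacterModel(
        network_fn=MLP,
        network_args={"input_size": 28 * 28, "num_classes": 80},  # example values
    )

    # save_checkpoint() no longer takes a path argument; the experiment
    # directory must be set on the instance first or a ValueError is raised.
    model.model_dir = Path("experiments/run_01")  # assumed location
    model.save_checkpoint(is_best=True, epoch=1, val_metric="val_accuracy")

    # Omitting network_args makes the base class call load_weights(network_fn),
    # which restores the stored network_args and the trained weights,
    # provided weights have previously been saved to self.weights_filename.
    pretrained = CharacterModel(network_fn=MLP)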
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 527fc7d..f1dabb7 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,12 +1,15 @@
 """Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, Optional, Tuple, Type, Union

 import numpy as np
 import torch
 from torch import nn
 from torchvision.transforms import ToTensor

-from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
+from text_recognizer.datasets.emnist_dataset import (
+    _augment_emnist_mapping,
+    _load_emnist_essentials,
+)
 from text_recognizer.models.base import Model


@@ -16,7 +19,7 @@ class CharacterModel(Model):
     def __init__(
         self,
         network_fn: Type[nn.Module],
-        network_args: Dict,
+        network_args: Optional[Dict] = None,
         data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
@@ -44,19 +47,23 @@ class CharacterModel(Model):
             lr_scheduler_args,
             device,
         )
-        self.load_mapping()
+        if self.mapping is None:
+            self.load_mapping()
         self.tensor_transform = ToTensor()
         self.softmax = nn.Softmax(dim=0)

     def load_mapping(self) -> None:
         """Mapping between integers and classes."""
-        self._mapping = load_emnist_mapping()
+        essentials = _load_emnist_essentials()
+        self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))

-    def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
+    def predict_on_image(
+        self, image: Union[np.ndarray, torch.Tensor]
+    ) -> Tuple[str, float]:
         """Character prediction on an image.

         Args:
-            image (np.ndarray): An image containing a character.
+            image (Union[np.ndarray, torch.Tensor]): An image containing a character.

         Returns:
             Tuple[str, float]: The predicted character and the confidence in the prediction.
@@ -64,12 +71,15 @@
         """

         if image.dtype == np.uint8:
-            image = (image / 255).astype(np.float32)
-
-        # Conver to Pytorch Tensor.
-        image = self.tensor_transform(image)
+            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255

         with torch.no_grad():
+            # Put the image tensor on the device the model weights are on.
+            image = image.to(self.device)
             logits = self.network(image)
             prediction = self.softmax(logits.data.squeeze())
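Note: a short sketch of the updated predict_on_image behaviour, continuing from the model built in the sketch above; the input shapes and zero-valued images are illustrative assumptions.

    import numpy as np
    import torch

    # A uint8 NumPy image is converted by ToTensor into a float tensor in [0, 1].
    image_np = np.zeros((28, 28), dtype=np.uint8)
    character, confidence = model.predict_on_image(image_np)

    # A uint8 tensor is treated as unscaled, divided by 255, and then moved
    # to the device the network weights are on before inference.
    image_t = torch.zeros((1, 28, 28), dtype=torch.uint8)
    character, confidence = model.predict_on_image(image_t)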