diff options
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r-- | src/text_recognizer/models/base.py | 66 |
1 files changed, 46 insertions, 20 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 0cc531a..b78eacb 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -1,9 +1,11 @@ """Abstract Model class for PyTorch neural networks.""" from abc import ABC, abstractmethod +from glob import glob from pathlib import Path +import re import shutil -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Type from loguru import logger import torch @@ -19,7 +21,7 @@ class Model(ABC): def __init__( self, - network_fn: Callable, + network_fn: Type[nn.Module], network_args: Dict, data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, @@ -35,7 +37,7 @@ class Model(ABC): """Base class, to be inherited by model for specific type of data. Args: - network_fn (Callable): The PyTorch network. + network_fn (Type[nn.Module]): The PyTorch network. network_args (Dict): Arguments for the network. data_loader (Optional[Callable]): A function that fetches train and val DataLoader. data_loader_args (Optional[Dict]): Arguments for the DataLoader. @@ -57,27 +59,29 @@ class Model(ABC): self._data_loaders = data_loader(**data_loader_args) dataset_name = self._data_loaders.__name__ else: - dataset_name = "" + dataset_name = "*" self._data_loaders = None - self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" + self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" # Extract the input shape for the torchsummary. - self._input_shape = network_args.pop("input_shape") + if isinstance(network_args["input_size"], int): + self._input_shape = (1,) + tuple([network_args["input_size"]]) + else: + self._input_shape = (1,) + tuple(network_args["input_size"]) if metrics is not None: self._metrics = metrics # Set the device. - if self.device is None: - self._device = torch.device( - "cuda:0" if torch.cuda.is_available() else "cpu" - ) + if device is None: + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self._device = device # Load network. - self._network = network_fn(**network_args) + self.network_args = network_args + self._network = network_fn(**self.network_args) # To device. self._network.to(self._device) @@ -95,13 +99,29 @@ class Model(ABC): # Set learning rate scheduler. self._lr_scheduler = None if lr_scheduler is not None: + # OneCycleLR needs the number of steps in an epoch as an input argument. + if "OneCycleLR" in str(lr_scheduler): + lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train")) self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) + # Class mapping. + self._mapping = None + + @property + def __name__(self) -> str: + """Returns the name of the model.""" + return self._name + @property def input_shape(self) -> Tuple[int, ...]: """The input shape.""" return self._input_shape + @property + def mapping(self) -> Dict: + """Returns the class mapping.""" + return self._mapping + def eval(self) -> None: """Sets the network to evaluation mode.""" self._network.eval() @@ -149,13 +169,14 @@ class Model(ABC): def weights_filename(self) -> str: """Filepath to the network weights.""" WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) - return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt") + return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") def summary(self) -> None: """Prints a summary of the network architecture.""" - summary(self._network, self._input_shape, device=self.device) + device = re.sub("[^A-Za-z]+", "", self.device) + summary(self._network, self._input_shape, device=device) - def _get_state(self) -> Dict: + def _get_state_dict(self) -> Dict: """Get the state dict of the model.""" state = {"model_state": self._network.state_dict()} if self._optimizer is not None: @@ -172,6 +193,7 @@ class Model(ABC): epoch (int): The last epoch when the checkpoint was created. """ + logger.debug("Loading checkpoint...") if not path.exists(): logger.debug("File does not exist {str(path)}") @@ -200,6 +222,7 @@ class Model(ABC): state = self._get_state_dict() state["is_best"] = is_best state["epoch"] = epoch + state["network_args"] = self.network_args path.mkdir(parents=True, exist_ok=True) @@ -216,15 +239,18 @@ class Model(ABC): def load_weights(self) -> None: """Load the network weights.""" logger.debug("Loading network weights.") - weights = torch.load(self.weights_filename)["model_state"] + filename = glob(self.weights_filename)[0] + weights = torch.load(filename, map_location=torch.device(self._device))[ + "model_state" + ] self._network.load_state_dict(weights) - def save_weights(self) -> None: + def save_weights(self, path: Path) -> None: """Save the network weights.""" - logger.debug("Saving network weights.") - torch.save({"model_state": self._network.state_dict()}, self.weights_filename) + logger.debug("Saving the best network weights.") + shutil.copyfile(str(path / "best.pt"), self.weights_filename) @abstractmethod - def mapping(self) -> Dict: - """Mapping from network output to class.""" + def load_mapping(self) -> None: + """Loads class mapping from network output to character.""" ... |