author     aktersnurra <gustaf.rydholm@gmail.com>  2020-06-23 22:39:54 +0200
committer  aktersnurra <gustaf.rydholm@gmail.com>  2020-06-23 22:39:54 +0200
commit     7c4de6d88664d2ea1b084f316a11896dde3e1150 (patch)
tree       cbde7e64c8064e9b523dfb65cd6c487d061ec805 /src/text_recognizer/models/base.py
parent     a7a9ce59fc37317dd74c3a52caf7c4e68e570ee8 (diff)
latest
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--  src/text_recognizer/models/base.py  230
1 files changed, 230 insertions, 0 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
new file mode 100644
index 0000000..736af7b
--- /dev/null
+++ b/src/text_recognizer/models/base.py
@@ -0,0 +1,230 @@
+"""Abstract Model class for PyTorch neural networks."""
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+import shutil
+from typing import Callable, Dict, Optional, Tuple
+
+from loguru import logger
+import torch
+from torch import nn
+from torchsummary import summary
+
+from text_recognizer.dataset.data_loader import fetch_data_loader
+
+WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
+
+
+class Model(ABC):
+ """Abstract Model class with composition of different parts defining a PyTorch neural network."""
+
+ def __init__(
+ self,
+ network_fn: Callable,
+ network_args: Dict,
+ data_loader_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+        """Base class, to be inherited by predictors for a specific type of data.
+
+ Args:
+ network_fn (Callable): The PyTorch network.
+ network_args (Dict): Arguments for the network.
+ data_loader_args (Optional[Dict]): Arguments for the data loader.
+ metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
+            criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
+ Defaults to None.
+ criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
+ optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
+ optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
+ lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
+ lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
+ None.
+ device (Optional[str]): Name of the device to train on. Defaults to None.
+
+ """
+
+ # Fetch data loaders.
+ if data_loader_args is not None:
+ self._data_loaders = fetch_data_loader(**data_loader_args)
+            dataset_name = next(iter(self._data_loaders.values())).dataset.__class__.__name__
+ else:
+ dataset_name = ""
+ self._data_loaders = None
+
+ self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+
+ # Extract the input shape for the torchsummary.
+ self._input_shape = network_args.pop("input_shape")
+
+        # Set the metrics.
+        self._metrics = metrics
+
+ # Set the device.
+        if device is None:
+            self._device = torch.device(
+                "cuda:0" if torch.cuda.is_available() else "cpu"
+            )
+        else:
+            self._device = torch.device(device)
+
+ # Load network.
+ self._network = network_fn(**network_args)
+
+ # To device.
+ self._network.to(self._device)
+
+ # Set criterion.
+ self._criterion = None
+ if criterion is not None:
+            self._criterion = criterion(**(criterion_args or {}))
+
+ # Set optimizer.
+ self._optimizer = None
+ if optimizer is not None:
+            self._optimizer = optimizer(self._network.parameters(), **(optimizer_args or {}))
+
+ # Set learning rate scheduler.
+ self._lr_scheduler = None
+ if lr_scheduler is not None:
+            self._lr_scheduler = lr_scheduler(self._optimizer, **(lr_scheduler_args or {}))
+
+ @property
+ def input_shape(self) -> Tuple[int, ...]:
+ """The input shape."""
+ return self._input_shape
+
+ def eval(self) -> None:
+ """Sets the network to evaluation mode."""
+ self._network.eval()
+
+ def train(self) -> None:
+ """Sets the network to train mode."""
+ self._network.train()
+
+ @property
+    def device(self) -> torch.device:
+ """Device where the weights are stored, i.e. cpu or cuda."""
+ return self._device
+
+ @property
+ def metrics(self) -> Optional[Dict]:
+ """Metrics."""
+ return self._metrics
+
+ @property
+ def criterion(self) -> Optional[Callable]:
+ """Criterion."""
+ return self._criterion
+
+ @property
+ def optimizer(self) -> Optional[Callable]:
+ """Optimizer."""
+ return self._optimizer
+
+ @property
+ def lr_scheduler(self) -> Optional[Callable]:
+ """Learning rate scheduler."""
+ return self._lr_scheduler
+
+ @property
+ def data_loaders(self) -> Optional[Dict]:
+ """Dataloaders."""
+ return self._data_loaders
+
+ @property
+ def network(self) -> nn.Module:
+ """Neural network."""
+ return self._network
+
+ @property
+ def weights_filename(self) -> str:
+ """Filepath to the network weights."""
+ WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt")
+
+ def summary(self) -> None:
+ """Prints a summary of the network architecture."""
+        summary(self._network, self._input_shape, device=self._device.type)
+
+ def _get_state(self) -> Dict:
+ """Get the state dict of the model."""
+ state = {"model_state": self._network.state_dict()}
+ if self._optimizer is not None:
+ state["optimizer_state"] = self._optimizer.state_dict()
+ return state
+
+ def load_checkpoint(self, path: Path) -> int:
+ """Load a previously saved checkpoint.
+
+ Args:
+ path (Path): Path to the experiment with the checkpoint.
+
+ Returns:
+ epoch (int): The last epoch when the checkpoint was created.
+
+ """
+ if not path.exists():
+            logger.debug(f"File does not exist {str(path)}")
+
+        checkpoint = torch.load(str(path), map_location=self._device)
+ self._network.load_state_dict(checkpoint["model_state"])
+
+ if self._optimizer is not None:
+ self._optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+ epoch = checkpoint["epoch"]
+
+ return epoch
+
+ def save_checkpoint(
+ self, path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
+ """Saves a checkpoint of the model.
+
+ Args:
+ path (Path): Path to the experiment folder.
+ is_best (bool): If it is the currently best model.
+ epoch (int): The epoch of the checkpoint.
+ val_metric (str): Validation metric.
+
+ """
+        state = self._get_state()
+ state["is_best"] = is_best
+ state["epoch"] = epoch
+
+ path.mkdir(parents=True, exist_ok=True)
+
+ logger.debug("Saving checkpoint...")
+ filepath = str(path / "last.pt")
+ torch.save(state, filepath)
+
+ if is_best:
+ logger.debug(
+ f"Found a new best {val_metric}. Saving best checkpoint and weights."
+ )
+ self.save_weights()
+ shutil.copyfile(filepath, str(path / "best.pt"))
+
+ def load_weights(self) -> None:
+ """Load the network weights."""
+ logger.debug("Loading network weights.")
+        weights = torch.load(self.weights_filename, map_location=self._device)["model_state"]
+ self._network.load_state_dict(weights)
+
+ def save_weights(self) -> None:
+ """Save the network weights."""
+ logger.debug("Saving network weights.")
+ torch.save({"model_state": self._network.state_dict()}, self.weights_filename)
+
+ @abstractmethod
+ def mapping(self) -> Dict:
+ """Mapping from network output to class."""
+ ...
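
Example usage (illustrative, not part of the commit above): a minimal sketch of how a concrete subclass might plug into this base class. The TinyMLP network, the CharacterModel subclass, and the constructor arguments below are assumptions made for the example, not code from this repository.

from typing import Dict

import torch
from torch import nn, optim

from text_recognizer.models.base import Model


class TinyMLP(nn.Module):
    """Stand-in network; any nn.Module constructed from network_args works."""

    def __init__(self, input_size: int = 28 * 28, num_classes: int = 62) -> None:
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(), nn.Linear(input_size, num_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class CharacterModel(Model):
    """Hypothetical predictor implementing the abstract mapping() method."""

    def mapping(self) -> Dict:
        # Placeholder index-to-character table; a real subclass would build it
        # from its dataset.
        return {index: chr(index + 32) for index in range(62)}


# network_args must contain "input_shape" (popped for torchsummary); the other
# keys are forwarded to the network constructor.
model = CharacterModel(
    network_fn=TinyMLP,
    network_args={"input_shape": (1, 28, 28), "input_size": 28 * 28, "num_classes": 62},
    criterion=nn.CrossEntropyLoss,
    criterion_args={},
    optimizer=optim.Adam,
    optimizer_args={"lr": 3e-4},
)
print(model.name)    # CharacterModel__TinyMLP (dataset name is empty without data loaders)
print(model.device)  # cuda:0 if available, otherwise cpu

Passing data_loader_args instead of None would additionally fill in the dataset name and expose data_loaders for a training loop.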
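
Continuing the sketch above, a checkpoint round trip with save_checkpoint and load_checkpoint; the experiment directory is an assumed path.

from pathlib import Path

experiment_dir = Path("experiments/character_model")

# Writes last.pt with the network and optimizer state plus the epoch; because
# is_best is True it also saves the weights file and copies last.pt to best.pt.
model.save_checkpoint(path=experiment_dir, is_best=True, epoch=1, val_metric="accuracy")

# Restores the network and optimizer state and returns the stored epoch.
epoch = model.load_checkpoint(path=experiment_dir / "last.pt")
assert epoch == 1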