path: root/text_recognizer/models/base.py
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-03-20 18:09:06 +0100
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-03-20 18:09:06 +0100
commit 7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree 996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/models/base.py
parent b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Initial commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--  text_recognizer/models/base.py  455
1 file changed, 455 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
new file mode 100644
index 0000000..70f4cdb
--- /dev/null
+++ b/text_recognizer/models/base.py
@@ -0,0 +1,455 @@
+"""Abstract Model class for PyTorch neural networks."""
+
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+import shutil
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, random_split
+from torchsummary import summary
+
+from text_recognizer import datasets
+from text_recognizer import networks
+from text_recognizer.datasets import EmnistMapper
+
+WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
+
+
+class Model(ABC):
+ """Abstract Model class with composition of different parts defining a PyTorch neural network."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Base class, to be inherited by model for specific type of data.
+
+ Args:
+            network_fn (str): The name of the network.
+            dataset (str): The name of the dataset class.
+            network_args (Optional[Dict]): Arguments for the network. Defaults to None.
+            dataset_args (Optional[Dict]): Arguments for the dataset. Defaults to None.
+ metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
+ Defaults to None.
+ criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
+ optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
+ optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
+ lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
+ lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
+ None.
+ swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+ None.
+ device (Optional[str]): Name of the device to train on. Defaults to None.
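+
+        Example:
+            A minimal sketch, assuming a concrete subclass (here called
+            ``CharacterModel``) and a ``CNN`` network and ``EmnistDataset``
+            registered in ``text_recognizer.networks`` and
+            ``text_recognizer.datasets``; all names below are illustrative:
+
+                import torch.nn as nn
+                import torch.optim as optim
+
+                model = CharacterModel(
+                    network_fn="CNN",
+                    dataset="EmnistDataset",
+                    network_args={"num_classes": 80},
+                    dataset_args={
+                        "args": {},
+                        "train_args": {
+                            "train_fraction": 0.9,
+                            "batch_size": 64,
+                            "num_workers": 4,
+                        },
+                    },
+                    criterion=nn.CrossEntropyLoss,
+                    optimizer=optim.Adam,
+                    optimizer_args={"lr": 3e-4},
+                )
+                model.prepare_data()
+                model.configure_model()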
+
+ """
+ self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
+ # Has to be set in subclass.
+ self._mapper = None
+
+ # Placeholder.
+ self._input_shape = None
+
+ self.dataset_name = dataset
+ self.dataset = None
+ self.dataset_args = dataset_args
+
+ # Placeholders for datasets.
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+
+ # Stochastic Weight Averaging placeholders.
+ self.swa_args = swa_args
+ self._swa_scheduler = None
+ self._swa_network = None
+ self._use_swa_model = False
+
+ # Experiment directory.
+ self.model_dir = None
+
+ # Flag for configured model.
+ self.is_configured = False
+ self.data_prepared = False
+
+ # Flag for stopping training.
+ self.stop_training = False
+
+        self._metrics = metrics
+
+ # Set the device.
+        self._device = (
+            torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            if device is None
+            else torch.device(device)
+        )
+
+ # Configure network.
+ self._network = None
+ self._network_args = network_args
+ self._configure_network(network_fn)
+
+ # Place network on device (GPU).
+ self.to_device()
+
+ # Loss and Optimizer placeholders for before loading.
+ self._criterion = criterion
+ self.criterion_args = criterion_args
+
+ self._optimizer = optimizer
+ self.optimizer_args = optimizer_args
+
+ self._lr_scheduler = lr_scheduler
+ self.lr_scheduler_args = lr_scheduler_args
+
+ def configure_model(self) -> None:
+ """Configures criterion and optimizers."""
+ if not self.is_configured:
+ self._configure_criterion()
+ self._configure_optimizers()
+
+ # Set this flag to true to prevent the model from configuring again.
+ self.is_configured = True
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ # TODO add downloading.
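+        # ``dataset_args`` is expected to have roughly this layout (the values
+        # here are illustrative, not a fixed schema):
+        #
+        #     {
+        #         "args": {...},  # Keyword arguments for the dataset class.
+        #         "train_args": {
+        #             "train_fraction": 0.9,
+        #             "batch_size": 64,
+        #             "num_workers": 4,
+        #         },
+        #     }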
+ if not self.data_prepared:
+ # Load dataset module.
+ self.dataset = getattr(datasets, self.dataset_name)
+
+ # Load train dataset.
+ train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+ train_dataset.load_or_generate_data()
+
+ # Set input shape.
+ self._input_shape = train_dataset.input_shape
+
+ # Split train dataset into a training and validation partition.
+ dataset_len = len(train_dataset)
+ train_len = int(
+ self.dataset_args["train_args"]["train_fraction"] * dataset_len
+ )
+ val_len = dataset_len - train_len
+ self.train_dataset, self.val_dataset = random_split(
+ train_dataset, lengths=[train_len, val_len]
+ )
+
+ # Load test dataset.
+ self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+ self.test_dataset.load_or_generate_data()
+
+            # Set the flag to True to prevent the data from being prepared again.
+ self.data_prepared = True
+
+ def train_dataloader(self) -> DataLoader:
+ """Returns data loader for training set."""
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ """Returns data loader for validation set."""
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ """Returns data loader for test set."""
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=False,
+ pin_memory=True,
+ )
+
+    def _configure_network(self, network_fn: str) -> None:
+        """Loads the network."""
+        # Resolve the network class by name.
+        network_fn = getattr(networks, network_fn)
+        # If no network arguments are given, load pretrained weights if they exist.
+ if self._network_args is None:
+ self.load_weights(network_fn)
+ else:
+ self._network = network_fn(**self._network_args)
+
+ def _configure_criterion(self) -> None:
+ """Loads the criterion."""
+ self._criterion = (
+ self._criterion(**self.criterion_args)
+ if self._criterion is not None
+ else None
+ )
+
+    def _configure_optimizers(self) -> None:
+ """Loads the optimizers."""
+        if self._optimizer is not None:
+            self._optimizer = self._optimizer(
+                self._network.parameters(), **self.optimizer_args
+            )
+
+        if self._optimizer is not None and self._lr_scheduler is not None:
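+            # Schedulers such as OneCycleLR expect "steps_per_epoch" to equal
+            # the number of batches per epoch, i.e. the dataloader length.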
+ if "steps_per_epoch" in self.lr_scheduler_args:
+ self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+
+ # Assume lr scheduler should update at each epoch if not specified.
+ if "interval" not in self.lr_scheduler_args:
+ interval = "epoch"
+ else:
+ interval = self.lr_scheduler_args.pop("interval")
+ self._lr_scheduler = {
+ "lr_scheduler": self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ ),
+ "interval": interval,
+ }
+
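+        # ``swa_args`` is expected to look like {"start": <epoch>, "lr": <swa_lr>},
+        # e.g. {"start": 2, "lr": 0.05} (values illustrative): "start" is the
+        # epoch at which SWA begins and "lr" is the learning rate used by SWALR.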
+ if self.swa_args is not None:
+ self._swa_scheduler = {
+ "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
+ "swa_start": self.swa_args["start"],
+ }
+ self._swa_network = AveragedModel(self._network).to(self.device)
+
+ @property
+ def name(self) -> str:
+ """Returns the name of the model."""
+ return self._name
+
+ @property
+ def input_shape(self) -> Tuple[int, ...]:
+ """The input shape."""
+ return self._input_shape
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the mapper that maps between ints and chars."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between network output and Emnist character."""
+ return self._mapper.mapping if self._mapper is not None else None
+
+ def eval(self) -> None:
+ """Sets the network to evaluation mode."""
+ self._network.eval()
+
+ def train(self) -> None:
+ """Sets the network to train mode."""
+ self._network.train()
+
+ @property
+ def device(self) -> str:
+ """Device where the weights are stored, i.e. cpu or cuda."""
+ return self._device
+
+ @property
+ def metrics(self) -> Optional[Dict]:
+ """Metrics."""
+ return self._metrics
+
+ @property
+ def criterion(self) -> Optional[Callable]:
+ """Criterion."""
+ return self._criterion
+
+ @property
+ def optimizer(self) -> Optional[Callable]:
+ """Optimizer."""
+ return self._optimizer
+
+ @property
+ def lr_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the learning rate scheduler."""
+ return self._lr_scheduler
+
+ @property
+ def swa_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the stochastic weight averaging scheduler."""
+ return self._swa_scheduler
+
+ @property
+ def swa_network(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging network."""
+ return self._swa_network
+
+ @property
+ def network(self) -> Type[nn.Module]:
+ """Neural network."""
+        return self._network
+
+ @property
+ def weights_filename(self) -> str:
+ """Filepath to the network weights."""
+ WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
+
+ def use_swa_model(self) -> None:
+ """Set to use predictions from SWA model."""
+ if self.swa_network is not None:
+ self._use_swa_model = True
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass with the network."""
+ if self._use_swa_model:
+ return self.swa_network(x)
+ else:
+ return self.network(x)
+
+ def summary(
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 3,
+ device: Optional[str] = None,
+ ) -> None:
+ """Prints a summary of the network architecture."""
+ device = self.device if device is None else device
+
+ if input_shape is not None:
+ summary(self.network, input_shape, depth=depth, device=device)
+ elif self._input_shape is not None:
+ input_shape = tuple(self._input_shape)
+ summary(self.network, input_shape, depth=depth, device=device)
+ else:
+ logger.warning("Could not print summary as input shape is not set.")
+
+ def to_device(self) -> None:
+ """Places the network on the device (GPU)."""
+ self._network.to(self._device)
+
+ def _get_state_dict(self) -> Dict:
+ """Get the state dict of the model."""
+ state = {"model_state": self._network.state_dict()}
+
+ if self._optimizer is not None:
+ state["optimizer_state"] = self._optimizer.state_dict()
+
+ if self._lr_scheduler is not None:
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
+
+ if self._swa_network is not None:
+ state["swa_network"] = self._swa_network.state_dict()
+
+ return state
+
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
+ """Load a previously saved checkpoint.
+
+ Args:
+            checkpoint_path (Union[str, Path]): Path to the checkpoint file.
+
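+        Example:
+            A minimal sketch (the path is illustrative):
+
+                model.load_from_checkpoint("experiments/my_run/last.pt")
+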
+ """
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
+ logger.debug("Loading checkpoint...")
+        if not checkpoint_path.exists():
+            raise FileNotFoundError(f"Checkpoint does not exist: {checkpoint_path}")
+
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
+ self._network.load_state_dict(checkpoint["model_state"])
+
+ if self._optimizer is not None:
+ self._optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+ if self._lr_scheduler is not None:
+ # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
+ # with OneCycleLR.
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
+
+ if self._swa_network is not None:
+ self._swa_network.load_state_dict(checkpoint["swa_network"])
+
+ def save_checkpoint(
+ self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
+ """Saves a checkpoint of the model.
+
+ Args:
+            checkpoint_path (Path): Directory in which to save the checkpoint.
+            is_best (bool): If it is the currently best model.
+            epoch (int): The epoch of the checkpoint.
+            val_metric (str): Name of the validation metric.
+
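+        Example:
+            A minimal sketch (the directory and metric name are illustrative):
+
+                model.save_checkpoint(
+                    Path("experiments/my_run"),
+                    is_best=True,
+                    epoch=10,
+                    val_metric="val_loss",
+                )
+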
+ """
+ state = self._get_state_dict()
+ state["is_best"] = is_best
+ state["epoch"] = epoch
+ state["network_args"] = self._network_args
+
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
+
+ logger.debug("Saving checkpoint...")
+ filepath = str(checkpoint_path / "last.pt")
+ torch.save(state, filepath)
+
+ if is_best:
+ logger.debug(
+ f"Found a new best {val_metric}. Saving best checkpoint and weights."
+ )
+ shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
+
+ def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
+ """Load the network weights."""
+ logger.debug("Loading network with pretrained weights.")
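+        # The weights file is expected to contain at least "network_args" and
+        # "model_state", and optionally "swa_network" (cf. save_checkpoint).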
+        filenames = glob(self.weights_filename)
+        if not filenames:
+            raise FileNotFoundError(
+                f"Could not find any pretrained weights at {self.weights_filename}"
+            )
+        filename = filenames[0]
+        # Load the state dict.
+ state_dict = torch.load(filename, map_location=torch.device(self._device))
+ self._network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+        # Initialize the network with the trained weights.
+ if network_fn is not None:
+ self._network = network_fn(**self._network_args)
+ self._network.load_state_dict(weights)
+
+ if "swa_network" in state_dict:
+ self._swa_network = AveragedModel(self._network).to(self.device)
+ self._swa_network.load_state_dict(state_dict["swa_network"])
+
+ def save_weights(self, path: Path) -> None:
+ """Save the network weights."""
+ logger.debug("Saving the best network weights.")
+ shutil.copyfile(str(path / "best.pt"), self.weights_filename)