Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--  src/text_recognizer/models/base.py  331
1 file changed, 218 insertions(+), 113 deletions(-)
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 3a84a11..153e19a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 from glob import glob
+import importlib
 from pathlib import Path
 import re
 import shutil
@@ -10,9 +11,12 @@ from typing import Callable, Dict, Optional, Tuple, Type
 from loguru import logger
 import torch
 from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, Dataset, random_split
 from torchsummary import summary
 
-from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
+from text_recognizer.datasets import EmnistMapper
 
 WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
 
@@ -23,8 +27,9 @@ class Model(ABC):
     def __init__(
         self,
         network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
         network_args: Optional[Dict] = None,
-        data_loader_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
         criterion_args: Optional[Dict] = None,
@@ -32,14 +37,16 @@
         optimizer_args: Optional[Dict] = None,
         lr_scheduler: Optional[Callable] = None,
         lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
         device: Optional[str] = None,
     ) -> None:
         """Base class, to be inherited by model for specific type of data.
 
         Args:
             network_fn (Type[nn.Module]): The PyTorch network.
+            dataset (Type[Dataset]): A dataset class.
             network_args (Optional[Dict]): Arguments for the network. Defaults to None.
-            data_loader_args (Optional[Dict]): Arguments for the DataLoader.
+            dataset_args (Optional[Dict]): Arguments for the dataset.
             metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
             criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
                 Defaults to None.
@@ -49,107 +56,181 @@
             lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
             lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
                 None.
+            swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+                None.
             device (Optional[str]): Name of the device to train on. Defaults to None.
 
         """
+        # Has to be set in subclass.
+        self._mapper = None
 
-        # Configure data loaders and dataset info.
-        dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
-            data_loader_args
-        )
-        self._input_shape = self._mapper.input_shape
+        # Placeholder.
+        self._input_shape = None
+
+        self.dataset = dataset
+        self.dataset_args = dataset_args
+
+        # Placeholders for datasets.
+        self.train_dataset = None
+        self.val_dataset = None
+        self.test_dataset = None
 
-        self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+        # Stochastic Weight Averaging placeholders.
+        self.swa_args = swa_args
+        self._swa_start = None
+        self._swa_scheduler = None
+        self._swa_network = None
 
-        if metrics is not None:
-            self._metrics = metrics
+        # Experiment directory.
+        self.model_dir = None
+
+        # Flag for configured model.
+        self.is_configured = False
+        self.data_prepared = False
+
+        # Flag for stopping training.
+        self.stop_training = False
+
+        self._name = (
+            f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}"
+        )
+
+        self._metrics = metrics
 
         # Set the device.
-        if device is None:
-            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self._device = device
+        self._device = (
+            torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            if device is None
+            else device
+        )
 
         # Configure network.
-        self._network, self._network_args = self._configure_network(
-            network_fn, network_args
-        )
+        self._network = None
+        self._network_args = network_args
+        self._configure_network(network_fn)
 
-        # To device.
-        self._network.to(self._device)
+        # Place network on device (GPU).
+        self.to_device()
+
+        # Loss and optimizer placeholders, instantiated in configure_model().
+        self._criterion = criterion
+        self.criterion_args = criterion_args
+
+        self._optimizer = optimizer
+        self.optimizer_args = optimizer_args
+
+        self._lr_scheduler = lr_scheduler
+        self.lr_scheduler_args = lr_scheduler_args
+
+    def configure_model(self) -> None:
+        """Configures criterion and optimizers."""
+        if not self.is_configured:
+            self._configure_criterion()
+            self._configure_optimizers()
+
+            # Prints a summary of the network in terminal.
+            self.summary()
+
+            # Set this flag to true to prevent the model from configuring again.
+            self.is_configured = True
+
+    def prepare_data(self) -> None:
+        """Prepare data for training."""
+        # TODO: add downloading.
+        if not self.data_prepared:
+            # Load train dataset.
+            train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+
+            # Set input shape.
+            self._input_shape = train_dataset.input_shape
+
+            # Split train dataset into a training and validation partition.
+            dataset_len = len(train_dataset)
+            train_len = int(
+                self.dataset_args["train_args"]["train_fraction"] * dataset_len
+            )
+            val_len = dataset_len - train_len
+            self.train_dataset, self.val_dataset = random_split(
+                train_dataset, lengths=[train_len, val_len]
+            )
+
+            # Load test dataset.
+            self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+
+            # Set the flag to True to prevent the data from being loaded again.
+            self.data_prepared = True
 
-        # Configure training objects.
-        self._criterion = self._configure_criterion(criterion, criterion_args)
-        self._optimizer, self._lr_scheduler = self._configure_optimizers(
-            optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+    def train_dataloader(self) -> DataLoader:
+        """Returns data loader for training set."""
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.dataset_args["train_args"]["batch_size"],
+            num_workers=self.dataset_args["train_args"]["num_workers"],
+            shuffle=True,
+            pin_memory=True,
         )
 
-        # Experiment directory.
-        self.model_dir = None
+    def val_dataloader(self) -> DataLoader:
+        """Returns data loader for validation set."""
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.dataset_args["train_args"]["batch_size"],
+            num_workers=self.dataset_args["train_args"]["num_workers"],
+            shuffle=True,
+            pin_memory=True,
+        )
 
-        # Flag for stopping training.
-        self.stop_training = False
+    def test_dataloader(self) -> DataLoader:
+        """Returns data loader for test set."""
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.dataset_args["train_args"]["batch_size"],
+            num_workers=self.dataset_args["train_args"]["num_workers"],
+            shuffle=False,
+            pin_memory=True,
+        )
 
-    def _configure_data_loader(
-        self, data_loader_args: Optional[Dict]
-    ) -> Tuple[str, Dict, EmnistMapper]:
-        """Loads data loader, dataset name, and dataset mapper."""
-        if data_loader_args is not None:
-            data_loaders = fetch_data_loaders(**data_loader_args)
-            dataset = list(data_loaders.values())[0].dataset
-            dataset_name = dataset.__name__
-            mapper = dataset.mapper
-        else:
-            self._mapper = EmnistMapper()
-            dataset_name = "*"
-            data_loaders = None
-        return dataset_name, data_loaders, mapper
-
-    def _configure_network(
-        self, network_fn: Type[nn.Module], network_args: Optional[Dict]
-    ) -> Tuple[Type[nn.Module], Dict]:
+    def _configure_network(self, network_fn: Type[nn.Module]) -> None:
         """Loads the network."""
         # If no network arguments are given, load pretrained weights if they exist.
-        if network_args is None:
-            network, network_args = self.load_weights(network_fn)
+        if self._network_args is None:
+            self.load_weights(network_fn)
         else:
-            network = network_fn(**network_args)
-        return network, network_args
+            self._network = network_fn(**self._network_args)
 
-    def _configure_criterion(
-        self, criterion: Optional[Callable], criterion_args: Optional[Dict]
-    ) -> Optional[Callable]:
+    def _configure_criterion(self) -> None:
         """Loads the criterion."""
-        if criterion is not None:
-            _criterion = criterion(**criterion_args)
-        else:
-            _criterion = None
-        return _criterion
+        self._criterion = (
+            self._criterion(**self.criterion_args)
+            if self._criterion is not None
+            else None
+        )
 
-    def _configure_optimizers(
-        self,
-        optimizer: Optional[Callable],
-        optimizer_args: Optional[Dict],
-        lr_scheduler: Optional[Callable],
-        lr_scheduler_args: Optional[Dict],
-    ) -> Tuple[Optional[Callable], Optional[Callable]]:
+    def _configure_optimizers(self) -> None:
         """Loads the optimizers."""
-        if optimizer is not None:
-            _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+        if self._optimizer is not None:
+            self._optimizer = self._optimizer(
+                self._network.parameters(), **self.optimizer_args
+            )
         else:
-            _optimizer = None
+            self._optimizer = None
 
-        if _optimizer and lr_scheduler is not None:
-            if "OneCycleLR" in str(lr_scheduler):
-                lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
-            _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args)
+        if self._optimizer and self._lr_scheduler is not None:
+            if "OneCycleLR" in str(self._lr_scheduler):
+                self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+            self._lr_scheduler = self._lr_scheduler(
+                self._optimizer, **self.lr_scheduler_args
+            )
         else:
-            _lr_scheduler = None
+            self._lr_scheduler = None
 
-        return _optimizer, _lr_scheduler
+        if self.swa_args is not None:
+            self._swa_start = self.swa_args["start"]
+            self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+            self._swa_network = AveragedModel(self._network).to(self.device)
 
     @property
-    def __name__(self) -> str:
+    def name(self) -> str:
         """Returns the name of the model."""
         return self._name
 
@@ -159,7 +240,7 @@ class Model(ABC):
         return self._input_shape
 
     @property
-    def mapper(self) -> Dict:
+    def mapper(self) -> EmnistMapper:
         """Returns the mapper that maps between ints and chars."""
         return self._mapper
 
@@ -202,13 +283,24 @@ class Model(ABC):
         return self._lr_scheduler
 
     @property
-    def data_loaders(self) -> Optional[Dict]:
-        """Dataloaders."""
-        return self._data_loaders
+    def swa_scheduler(self) -> Optional[Callable]:
+        """Returns the stochastic weight averaging scheduler."""
+        return self._swa_scheduler
+
+    @property
+    def swa_start(self) -> Optional[int]:
+        """Returns the start epoch of stochastic weight averaging."""
+        return self._swa_start
 
     @property
-    def network(self) -> nn.Module:
+    def swa_network(self) -> Optional[Callable]:
+        """Returns the stochastic weight averaging network."""
+        return self._swa_network
+
+    @property
+    def network(self) -> nn.Module:
         """Neural network."""
         return self._network
 
     @property
@@ -217,15 +309,27 @@ class Model(ABC):
         WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
         return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
 
-    def summary(self) -> None:
+    def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+        """Compute the loss."""
+        return self.criterion(output, targets)
+
+    def summary(
+        self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+    ) -> None:
         """Prints a summary of the network architecture."""
-        device = re.sub("[^A-Za-z]+", "", self.device)
-        if self._input_shape is not None:
+        if input_shape is not None:
+            summary(self._network, input_shape, depth=depth, device=self.device)
+        elif self._input_shape is not None:
             input_shape = (1,) + tuple(self._input_shape)
-            summary(self._network, input_shape, device=device)
+            summary(self._network, input_shape, depth=depth, device=self.device)
         else:
             logger.warning("Could not print summary as input shape is not set.")
 
+    def to_device(self) -> None:
+        """Places the network on the device (GPU)."""
+        self._network.to(self._device)
+
     def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
         state = {"model_state": self._network.state_dict()}
@@ -236,69 +340,67 @@
         if self._lr_scheduler is not None:
             state["scheduler_state"] = self._lr_scheduler.state_dict()
 
+        if self._swa_network is not None:
+            state["swa_network"] = self._swa_network.state_dict()
+
         return state
 
-    def load_checkpoint(self, path: Path) -> int:
+    def load_from_checkpoint(self, checkpoint_path: Path) -> None:
         """Load a previously saved checkpoint.
 
         Args:
-            path (Path): Path to the experiment with the checkpoint.
-
-        Returns:
-            epoch (int): The last epoch when the checkpoint was created.
+            checkpoint_path (Path): Path to the experiment with the checkpoint.
 
         """
         logger.debug("Loading checkpoint...")
-        if not path.exists():
-            logger.debug("File does not exist {str(path)}")
+        if not checkpoint_path.exists():
+            logger.debug(f"File does not exist {str(checkpoint_path)}")
 
-        checkpoint = torch.load(str(path))
+        checkpoint = torch.load(str(checkpoint_path))
         self._network.load_state_dict(checkpoint["model_state"])
 
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])
 
-        # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs.
-        # if self._lr_scheduler is not None:
-        #     self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
-
-        epoch = checkpoint["epoch"]
+        if self._lr_scheduler is not None:
+            # Does not work when loading from a previous checkpoint and trying to train beyond the last
+            # max epochs with OneCycleLR.
+            if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
+                self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
 
-        return epoch
+        if self._swa_network is not None:
+            self._swa_network.load_state_dict(checkpoint["swa_network"])
 
-    def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
+    def save_checkpoint(
+        self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+    ) -> None:
         """Saves a checkpoint of the model.
 
         Args:
+            checkpoint_path (Path): Path to the experiment with the checkpoint.
             is_best (bool): If it is the currently best model.
             epoch (int): The epoch of the checkpoint.
            val_metric (str): Validation metric.
 
-        Raises:
-            ValueError: If the self.model_dir is not set.
-
         """
         state = self._get_state_dict()
         state["is_best"] = is_best
        state["epoch"] = epoch
         state["network_args"] = self._network_args
 
-        if self.model_dir is None:
-            raise ValueError("Experiment directory is not set.")
-
-        self.model_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path.mkdir(parents=True, exist_ok=True)
 
         logger.debug("Saving checkpoint...")
-        filepath = str(self.model_dir / "last.pt")
+        filepath = str(checkpoint_path / "last.pt")
         torch.save(state, filepath)
 
         if is_best:
             logger.debug(
                 f"Found a new best {val_metric}. Saving best checkpoint and weights."
             )
-            shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
+            shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
 
-    def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
+    def load_weights(self, network_fn: Type[nn.Module]) -> None:
         """Load the network weights."""
         logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
@@ -308,13 +410,16 @@ class Model(ABC):
         )
         # Load the state dict.
         state_dict = torch.load(filename, map_location=torch.device(self._device))
-        network_args = state_dict["network_args"]
+        self._network_args = state_dict["network_args"]
         weights = state_dict["model_state"]
 
         # Initializes the network with trained weights.
-        network = network_fn(**self._network_args)
-        network.load_state_dict(weights)
-        return network, network_args
+        self._network = network_fn(**self._network_args)
+        self._network.load_state_dict(weights)
+
+        if "swa_network" in state_dict:
+            self._swa_network = AveragedModel(self._network).to(self.device)
+            self._swa_network.load_state_dict(state_dict["swa_network"])
 
     def save_weights(self, path: Path) -> None:
         """Save the network weights."""
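The refactor replaces fetch_data_loaders with a dataset class that the model owns, splits, and wraps in loaders itself. A minimal sketch of constructing and preparing a model under the new API follows; CharacterModel, EmnistDataset, and MLP, as well as all argument values, are illustrative assumptions, while the dataset_args layout ("args" plus "train_args" with train_fraction, batch_size, and num_workers) and the swa_args keys ("start", "lr") come from the diff above.

# Hypothetical usage sketch; class names and argument values are assumptions.
from torch import nn, optim

from text_recognizer.datasets import EmnistDataset  # assumed dataset class
from text_recognizer.models import CharacterModel   # assumed Model subclass
from text_recognizer.networks import MLP            # assumed network

model = CharacterModel(
    network_fn=MLP,
    dataset=EmnistDataset,
    network_args={"input_size": 784, "num_classes": 80},
    dataset_args={
        # Forwarded verbatim to the dataset constructor.
        "args": {"sample_to_balance": True},
        # Consumed by prepare_data() and the *_dataloader() methods.
        "train_args": {"train_fraction": 0.8, "batch_size": 128, "num_workers": 4},
    },
    criterion=nn.CrossEntropyLoss,
    optimizer=optim.Adam,
    optimizer_args={"lr": 3e-4},
    swa_args={"start": 2, "lr": 5e-2},
)

model.prepare_data()     # load and split datasets first: OneCycleLR needs
model.configure_model()  # len(train_dataloader()) during configuration
train_loader = model.train_dataloader()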
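The SWA hooks (swa_start, SWALR, AveragedModel) imply a training loop shaped roughly as below, continuing the sketch above. This is not the project's Trainer: the optimizer and device accessors are assumed to exist as properties alongside the ones shown in the diff, and update_bn is the standard torch.optim.swa_utils helper for recomputing batch-norm statistics of the averaged weights.

# Hypothetical training-loop sketch built on the hooks from the diff.
import torch

max_epochs = 10  # illustrative value

for epoch in range(max_epochs):
    for data, targets in train_loader:
        data, targets = data.to(model.device), targets.to(model.device)
        model.optimizer.zero_grad()
        loss = model.loss_fn(model.network(data), targets)
        loss.backward()
        model.optimizer.step()

    if model.swa_start is not None and epoch >= model.swa_start:
        # Past the SWA start epoch: fold the current weights into the
        # running average and step the SWA learning rate schedule.
        model.swa_network.update_parameters(model.network)
        model.swa_scheduler.step()
    elif model.lr_scheduler is not None:
        model.lr_scheduler.step()

# Recompute batch-norm statistics for the averaged model before evaluation.
torch.optim.swa_utils.update_bn(train_loader, model.swa_network)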
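Checkpointing now takes an explicit directory instead of relying on self.model_dir, and the state dict gains a "swa_network" entry that both the save and load paths handle. A short round-trip sketch, with an illustrative path:

# Hypothetical checkpoint round-trip; the directory name is an assumption.
from pathlib import Path

checkpoint_dir = Path("experiments/CharacterModel_EmnistDataset_MLP")

# Writes last.pt, and copies it to best.pt when is_best=True.
model.save_checkpoint(checkpoint_dir, is_best=True, epoch=10, val_metric="accuracy")

# Restores network, optimizer, non-OneCycleLR scheduler, and SWA weights.
model.load_from_checkpoint(checkpoint_dir / "last.pt")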