Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r-- | src/text_recognizer/models/base.py | 66
1 files changed, 50 insertions, 16 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index caf8065..cc44c92 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -6,7 +6,7 @@ import importlib
 from pathlib import Path
 import re
 import shutil
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 from loguru import logger
 import torch
@@ -15,6 +15,7 @@ from torch import Tensor
 from torch.optim.swa_utils import AveragedModel, SWALR
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchsummary import summary
+from torchvision.transforms import Compose
 
 from text_recognizer.datasets import EmnistMapper
 
@@ -128,16 +129,42 @@ class Model(ABC):
         self._configure_criterion()
         self._configure_optimizers()
 
-        # Prints a summary of the network in terminal.
-        self.summary()
-
         # Set this flag to true to prevent the model from configuring again.
         self.is_configured = True
 
+    def _configure_transforms(self) -> None:
+        # Load transforms.
+        transforms_module = importlib.import_module(
+            "text_recognizer.datasets.transforms"
+        )
+        if (
+            "transform" in self.dataset_args["args"]
+            and self.dataset_args["args"]["transform"] is not None
+        ):
+            transform_ = []
+            for t in self.dataset_args["args"]["transform"]:
+                args = t["args"] or {}
+                transform_.append(getattr(transforms_module, t["type"])(**args))
+            self.dataset_args["args"]["transform"] = Compose(transform_)
+
+        if (
+            "target_transform" in self.dataset_args["args"]
+            and self.dataset_args["args"]["target_transform"] is not None
+        ):
+            target_transform_ = [
+                torch.tensor,
+            ]
+            for t in self.dataset_args["args"]["target_transform"]:
+                args = t["args"] or {}
+                target_transform_.append(getattr(transforms_module, t["type"])(**args))
+            self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
+
     def prepare_data(self) -> None:
         """Prepare data for training."""
         # TODO add downloading.
         if not self.data_prepared:
+            self._configure_transforms()
+
             # Load train dataset.
             train_dataset = self.dataset(train=True, **self.dataset_args["args"])
             train_dataset.load_or_generate_data()
@@ -327,20 +354,20 @@ class Model(ABC):
         else:
             return self.network(x)
 
-    def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
-        """Compute the loss."""
-        return self.criterion(output, targets)
-
     def summary(
-        self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
+        self,
+        input_shape: Optional[Union[List, Tuple]] = None,
+        depth: int = 4,
+        device: Optional[str] = None,
     ) -> None:
         """Prints a summary of the network architecture."""
+        device = self.device if device is None else device
         if input_shape is not None:
-            summary(self.network, input_shape, depth=depth, device=self.device)
+            summary(self.network, input_shape, depth=depth, device=device)
         elif self._input_shape is not None:
             input_shape = (1,) + tuple(self._input_shape)
-            summary(self.network, input_shape, depth=depth, device=self.device)
+            summary(self.network, input_shape, depth=depth, device=device)
         else:
             logger.warning("Could not print summary as input shape is not set.")
 
@@ -356,25 +383,29 @@
         state["optimizer_state"] = self._optimizer.state_dict()
 
         if self._lr_scheduler is not None:
-            state["scheduler_state"] = self._lr_scheduler.state_dict()
+            state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+            state["scheduler_interval"] = self._lr_scheduler["interval"]
 
         if self._swa_network is not None:
             state["swa_network"] = self._swa_network.state_dict()
 
         return state
 
-    def load_from_checkpoint(self, checkpoint_path: Path) -> None:
+    def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
         """Load a previously saved checkpoint.
 
         Args:
             checkpoint_path (Path): Path to the experiment with the checkpoint.
 
         """
+        checkpoint_path = Path(checkpoint_path)
+        self.prepare_data()
+        self.configure_model()
         logger.debug("Loading checkpoint...")
 
         if not checkpoint_path.exists():
             logger.debug("File does not exist {str(checkpoint_path)}")
 
-        checkpoint = torch.load(str(checkpoint_path))
+        checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
         self._network.load_state_dict(checkpoint["model_state"])
 
         if self._optimizer is not None:
@@ -383,8 +414,11 @@
         if self._lr_scheduler is not None:
             # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
             # with OneCycleLR.
-            if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
-                self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+            if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+                self._lr_scheduler["lr_scheduler"].load_state_dict(
+                    checkpoint["scheduler_state"]
+                )
+            self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
 
         if self._swa_network is not None:
             self._swa_network.load_state_dict(checkpoint["swa_network"])
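For context, the new _configure_transforms hook resolves transform descriptions from the dataset arguments against text_recognizer.datasets.transforms and wraps them in a torchvision Compose (prepending torch.tensor to the target transforms). Below is a minimal sketch of the configuration shape it iterates over; the "type" values are hypothetical class names chosen for illustration, not taken from this diff.

# Hypothetical dataset_args illustrating the structure _configure_transforms() expects.
# The "type" values (ToTensor, OneHotEncoding) are assumed names and must match classes
# defined in text_recognizer.datasets.transforms.
dataset_args = {
    "args": {
        "transform": [
            {"type": "ToTensor", "args": None},  # "args" may be None; the method falls back to {}
        ],
        "target_transform": [
            {"type": "OneHotEncoding", "args": {"num_classes": 80}},
        ],
    }
}
# After _configure_transforms() runs, the lists are replaced in place:
#   dataset_args["args"]["transform"]        -> Compose([ToTensor()])
#   dataset_args["args"]["target_transform"] -> Compose([torch.tensor, OneHotEncoding(num_classes=80)])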
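The checkpoint changes also assume that self._lr_scheduler is now a dict carrying the scheduler object plus its stepping interval, rather than the bare scheduler. A sketch of that assumed shape follows; the concrete scheduler and the "step"/"epoch" values are guesses at the convention, not confirmed by this diff.

# Assumed structure of self._lr_scheduler after this change: get_state() saves both keys,
# and load_from_checkpoint() restores them (skipping OneCycleLR state, which cannot be
# resumed beyond its original number of steps).
self._lr_scheduler = {
    "lr_scheduler": torch.optim.lr_scheduler.StepLR(self._optimizer, step_size=10),
    "interval": "step",  # assumed values: "step" or "epoch"
}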