diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
commit | 53677be4ec14854ea4881b0d78730e0414c8dedd (patch) | |
tree | 56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/text_recognizer/models/base.py | |
parent | 125d5da5fb845d03bda91426e172bca7f537584a (diff) |
Working bash scripts etc.
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r-- | src/text_recognizer/models/base.py | 139 |
1 files changed, 89 insertions, 50 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 84a86ca..6d40b49 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -12,6 +12,7 @@ import torch from torch import nn from torchsummary import summary +from text_recognizer.datasets import EmnistMapper, fetch_data_loaders WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -23,7 +24,6 @@ class Model(ABC): self, network_fn: Type[nn.Module], network_args: Optional[Dict] = None, - data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -39,7 +39,6 @@ class Model(ABC): Args: network_fn (Type[nn.Module]): The PyTorch network. network_args (Optional[Dict]): Arguments for the network. Defaults to None. - data_loader (Optional[Callable]): A function that fetches train and val DataLoader. data_loader_args (Optional[Dict]): Arguments for the DataLoader. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. @@ -54,15 +53,11 @@ class Model(ABC): """ - # Fetch data loaders. - if data_loader_args is not None: - self._data_loaders = data_loader(**data_loader_args) - dataset_name = self._data_loaders.__name__ - self._mapping = self._data_loaders.mapping - else: - self._mapping = None - dataset_name = "*" - self._data_loaders = None + # Fetch data loaders and dataset info. + dataset_name, self._data_loaders, self._mapper = self._load_data_loader( + data_loader_args + ) + self._input_shape = self._mapper.input_shape self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" @@ -76,40 +71,15 @@ class Model(ABC): self._device = device # Load network. - self._network = None - self._network_args = network_args - # If no network arguemnts are given, load pretrained weights if they exist. - if self._network_args is None: - self.load_weights(network_fn) - else: - self._network = network_fn(**self._network_args) + self._network, self._network_args = self._load_network(network_fn, network_args) # To device. self._network.to(self._device) - # Set criterion. - self._criterion = None - if criterion is not None: - self._criterion = criterion(**criterion_args) - - # Set optimizer. - self._optimizer = None - if optimizer is not None: - self._optimizer = optimizer(self._network.parameters(), **optimizer_args) - - # Set learning rate scheduler. - self._lr_scheduler = None - if lr_scheduler is not None: - # OneCycleLR needs the number of steps in an epoch as an input argument. - if "OneCycleLR" in str(lr_scheduler): - lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train")) - self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) - - # Extract the input shape for the torchsummary. - if isinstance(self._network_args["input_size"], int): - self._input_shape = (1,) + tuple([self._network_args["input_size"]]) - else: - self._input_shape = (1,) + tuple(self._network_args["input_size"]) + # Set training objects. + self._criterion = self._load_criterion(criterion, criterion_args) + self._optimizer = self._load_optimizer(optimizer, optimizer_args) + self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args) # Experiment directory. self.model_dir = None @@ -117,6 +87,64 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False + def _load_data_loader( + self, data_loader_args: Optional[Dict] + ) -> Tuple[str, Dict, EmnistMapper]: + """Loads data loader, dataset name, and dataset mapper.""" + if data_loader_args is not None: + data_loaders = fetch_data_loaders(**data_loader_args) + dataset = list(data_loaders.values())[0].dataset + dataset_name = dataset.__name__ + mapper = dataset.mapper + else: + self._mapper = EmnistMapper() + dataset_name = "*" + data_loaders = None + return dataset_name, data_loaders, mapper + + def _load_network( + self, network_fn: Type[nn.Module], network_args: Optional[Dict] + ) -> Tuple[Type[nn.Module], Dict]: + """Loads the network.""" + # If no network arguemnts are given, load pretrained weights if they exist. + if network_args is None: + network, network_args = self.load_weights(network_fn) + else: + network = network_fn(**network_args) + return network, network_args + + def _load_criterion( + self, criterion: Optional[Callable], criterion_args: Optional[Dict] + ) -> Optional[Callable]: + """Loads the criterion.""" + if criterion is not None: + _criterion = criterion(**criterion_args) + else: + _criterion = None + return _criterion + + def _load_optimizer( + self, optimizer: Optional[Callable], optimizer_args: Optional[Dict] + ) -> Optional[Callable]: + """Loads the optimizer.""" + if optimizer is not None: + _optimizer = optimizer(self._network.parameters(), **optimizer_args) + else: + _optimizer = None + return _optimizer + + def _load_lr_scheduler( + self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict] + ) -> Optional[Callable]: + """Loads learning rate scheduler.""" + if self._optimizer and lr_scheduler is not None: + if "OneCycleLR" in str(lr_scheduler): + lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) + _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) + else: + _lr_scheduler = None + return _lr_scheduler + @property def __name__(self) -> str: """Returns the name of the model.""" @@ -128,9 +156,14 @@ class Model(ABC): return self._input_shape @property + def mapper(self) -> Dict: + """Returns the mapper that maps between ints and chars.""" + return self._mapper + + @property def mapping(self) -> Dict: - """Returns the class mapping.""" - return self._mapping + """Returns the mapping between network output and Emnist character.""" + return self._mapper.mapping def eval(self) -> None: """Sets the network to evaluation mode.""" @@ -184,7 +217,11 @@ class Model(ABC): def summary(self) -> None: """Prints a summary of the network architecture.""" device = re.sub("[^A-Za-z]+", "", self.device) - summary(self._network, self._input_shape, device=device) + if self._input_shape is not None: + input_shape = (1,) + tuple(self._input_shape) + summary(self._network, input_shape, device=device) + else: + logger.warning("Could not print summary as input shape is not set.") def _get_state_dict(self) -> Dict: """Get the state dict of the model.""" @@ -218,8 +255,9 @@ class Model(ABC): if self._optimizer is not None: self._optimizer.load_state_dict(checkpoint["optimizer_state"]) - if self._lr_scheduler is not None: - self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) + # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs. + # if self._lr_scheduler is not None: + # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) epoch = checkpoint["epoch"] @@ -257,7 +295,7 @@ class Model(ABC): ) shutil.copyfile(filepath, str(self.model_dir / "best.pt")) - def load_weights(self, network_fn: Type[nn.Module]) -> None: + def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]: """Load the network weights.""" logger.debug("Loading network with pretrained weights.") filename = glob(self.weights_filename)[0] @@ -267,12 +305,13 @@ class Model(ABC): ) # Loading state directory. state_dict = torch.load(filename, map_location=torch.device(self._device)) - self._network_args = state_dict["network_args"] + network_args = state_dict["network_args"] weights = state_dict["model_state"] # Initializes the network with trained weights. - self._network = network_fn(**self._network_args) - self._network.load_state_dict(weights) + network = network_fn(**self._network_args) + network.load_state_dict(weights) + return network, network_args def save_weights(self, path: Path) -> None: """Save the network weights.""" |