diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-22 23:18:08 +0200 |
commit | f473456c19558aaf8552df97a51d4e18cc69dfa8 (patch) | |
tree | 0d35ce2410ff623ba5fb433d616d95b67ecf7a98 /src/text_recognizer/models | |
parent | ad3bd52530f4800d4fb05dfef3354921f95513af (diff) |
Working training loop and testing of trained CharacterModel.
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/__init__.py | 4 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 66 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 30 | ||||
-rw-r--r-- | src/text_recognizer/models/metrics.py | 2 |
4 files changed, 69 insertions, 33 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index d265dcf..ff10a07 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1,2 +1,6 @@ """Model modules.""" +from .base import Model from .character_model import CharacterModel +from .metrics import accuracy + +__all__ = ["Model", "CharacterModel", "accuracy"] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 0cc531a..b78eacb 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -1,9 +1,11 @@ """Abstract Model class for PyTorch neural networks.""" from abc import ABC, abstractmethod +from glob import glob from pathlib import Path +import re import shutil -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Type from loguru import logger import torch @@ -19,7 +21,7 @@ class Model(ABC): def __init__( self, - network_fn: Callable, + network_fn: Type[nn.Module], network_args: Dict, data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, @@ -35,7 +37,7 @@ class Model(ABC): """Base class, to be inherited by model for specific type of data. Args: - network_fn (Callable): The PyTorch network. + network_fn (Type[nn.Module]): The PyTorch network. network_args (Dict): Arguments for the network. data_loader (Optional[Callable]): A function that fetches train and val DataLoader. data_loader_args (Optional[Dict]): Arguments for the DataLoader. @@ -57,27 +59,29 @@ class Model(ABC): self._data_loaders = data_loader(**data_loader_args) dataset_name = self._data_loaders.__name__ else: - dataset_name = "" + dataset_name = "*" self._data_loaders = None - self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" + self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}" # Extract the input shape for the torchsummary. - self._input_shape = network_args.pop("input_shape") + if isinstance(network_args["input_size"], int): + self._input_shape = (1,) + tuple([network_args["input_size"]]) + else: + self._input_shape = (1,) + tuple(network_args["input_size"]) if metrics is not None: self._metrics = metrics # Set the device. - if self.device is None: - self._device = torch.device( - "cuda:0" if torch.cuda.is_available() else "cpu" - ) + if device is None: + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self._device = device # Load network. - self._network = network_fn(**network_args) + self.network_args = network_args + self._network = network_fn(**self.network_args) # To device. self._network.to(self._device) @@ -95,13 +99,29 @@ class Model(ABC): # Set learning rate scheduler. self._lr_scheduler = None if lr_scheduler is not None: + # OneCycleLR needs the number of steps in an epoch as an input argument. + if "OneCycleLR" in str(lr_scheduler): + lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train")) self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) + # Class mapping. + self._mapping = None + + @property + def __name__(self) -> str: + """Returns the name of the model.""" + return self._name + @property def input_shape(self) -> Tuple[int, ...]: """The input shape.""" return self._input_shape + @property + def mapping(self) -> Dict: + """Returns the class mapping.""" + return self._mapping + def eval(self) -> None: """Sets the network to evaluation mode.""" self._network.eval() @@ -149,13 +169,14 @@ class Model(ABC): def weights_filename(self) -> str: """Filepath to the network weights.""" WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) - return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt") + return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") def summary(self) -> None: """Prints a summary of the network architecture.""" - summary(self._network, self._input_shape, device=self.device) + device = re.sub("[^A-Za-z]+", "", self.device) + summary(self._network, self._input_shape, device=device) - def _get_state(self) -> Dict: + def _get_state_dict(self) -> Dict: """Get the state dict of the model.""" state = {"model_state": self._network.state_dict()} if self._optimizer is not None: @@ -172,6 +193,7 @@ class Model(ABC): epoch (int): The last epoch when the checkpoint was created. """ + logger.debug("Loading checkpoint...") if not path.exists(): logger.debug("File does not exist {str(path)}") @@ -200,6 +222,7 @@ class Model(ABC): state = self._get_state_dict() state["is_best"] = is_best state["epoch"] = epoch + state["network_args"] = self.network_args path.mkdir(parents=True, exist_ok=True) @@ -216,15 +239,18 @@ class Model(ABC): def load_weights(self) -> None: """Load the network weights.""" logger.debug("Loading network weights.") - weights = torch.load(self.weights_filename)["model_state"] + filename = glob(self.weights_filename)[0] + weights = torch.load(filename, map_location=torch.device(self._device))[ + "model_state" + ] self._network.load_state_dict(weights) - def save_weights(self) -> None: + def save_weights(self, path: Path) -> None: """Save the network weights.""" - logger.debug("Saving network weights.") - torch.save({"model_state": self._network.state_dict()}, self.weights_filename) + logger.debug("Saving the best network weights.") + shutil.copyfile(str(path / "best.pt"), self.weights_filename) @abstractmethod - def mapping(self) -> Dict: - """Mapping from network output to class.""" + def load_mapping(self) -> None: + """Loads class mapping from network output to character.""" ... diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index fd69bf2..527fc7d 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -1,5 +1,5 @@ """Defines the CharacterModel class.""" -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Type import numpy as np import torch @@ -8,7 +8,6 @@ from torchvision.transforms import ToTensor from text_recognizer.datasets.emnist_dataset import load_emnist_mapping from text_recognizer.models.base import Model -from text_recognizer.networks.mlp import mlp class CharacterModel(Model): @@ -16,8 +15,9 @@ class CharacterModel(Model): def __init__( self, - network_fn: Callable, + network_fn: Type[nn.Module], network_args: Dict, + data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -33,6 +33,7 @@ class CharacterModel(Model): super().__init__( network_fn, network_args, + data_loader, data_loader_args, metrics, criterion, @@ -43,13 +44,13 @@ class CharacterModel(Model): lr_scheduler_args, device, ) - self.emnist_mapping = self.mapping() - self.eval() + self.load_mapping() + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=0) - def mapping(self) -> Dict[int, str]: + def load_mapping(self) -> None: """Mapping between integers and classes.""" - mapping = load_emnist_mapping() - return mapping + self._mapping = load_emnist_mapping() def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]: """Character prediction on an image. @@ -61,15 +62,20 @@ class CharacterModel(Model): Tuple[str, float]: The predicted character and the confidence in the prediction. """ + if image.dtype == np.uint8: image = (image / 255).astype(np.float32) # Conver to Pytorch Tensor. - image = ToTensor(image) + image = self.tensor_transform(image) + + with torch.no_grad(): + logits = self.network(image) + + prediction = self.softmax(logits.data.squeeze()) - prediction = self.network(image) - index = torch.argmax(prediction, dim=1) + index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] - predicted_character = self.emnist_mapping[index] + predicted_character = self._mapping[index] return predicted_character, confidence_of_prediction diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index e2a30a9..ac8d68e 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -3,7 +3,7 @@ import torch -def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: +def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float: """Computes the accuracy. Args: |