From 53677be4ec14854ea4881b0d78730e0414c8dedd Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Sun, 9 Aug 2020 23:24:02 +0200
Subject: Working bash scripts etc.

---
 src/text_recognizer/models/base.py            | 139 +++++++++++++++++---------
 src/text_recognizer/models/character_model.py |  15 +--
 2 files changed, 90 insertions(+), 64 deletions(-)

(limited to 'src/text_recognizer/models')

diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 84a86ca..6d40b49 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -12,6 +12,7 @@ import torch
 from torch import nn
 from torchsummary import summary
 
+from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
 
 WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
 
@@ -23,7 +24,6 @@ class Model(ABC):
         self,
         network_fn: Type[nn.Module],
         network_args: Optional[Dict] = None,
-        data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
@@ -39,7 +39,6 @@ class Model(ABC):
         Args:
             network_fn (Type[nn.Module]): The PyTorch network.
             network_args (Optional[Dict]): Arguments for the network. Defaults to None.
-            data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
             data_loader_args (Optional[Dict]):  Arguments for the DataLoader.
             metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
             criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
@@ -54,15 +53,11 @@ class Model(ABC):
 
         """
 
-        # Fetch data loaders.
-        if data_loader_args is not None:
-            self._data_loaders = data_loader(**data_loader_args)
-            dataset_name = self._data_loaders.__name__
-            self._mapping = self._data_loaders.mapping
-        else:
-            self._mapping = None
-            dataset_name = "*"
-            self._data_loaders = None
+        # Fetch data loaders and dataset info.
+        dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+            data_loader_args
+        )
+        self._input_shape = self._mapper.input_shape
 
         self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
 
@@ -76,40 +71,15 @@ class Model(ABC):
             self._device = device
 
         # Load network.
-        self._network = None
-        self._network_args = network_args
-        # If no network arguemnts are given, load pretrained weights if they exist.
-        if self._network_args is None:
-            self.load_weights(network_fn)
-        else:
-            self._network = network_fn(**self._network_args)
+        self._network, self._network_args = self._load_network(network_fn, network_args)
 
         # To device.
         self._network.to(self._device)
 
-        # Set criterion.
-        self._criterion = None
-        if criterion is not None:
-            self._criterion = criterion(**criterion_args)
-
-        # Set optimizer.
-        self._optimizer = None
-        if optimizer is not None:
-            self._optimizer = optimizer(self._network.parameters(), **optimizer_args)
-
-        # Set learning rate scheduler.
-        self._lr_scheduler = None
-        if lr_scheduler is not None:
-            # OneCycleLR needs the number of steps in an epoch as an input argument.
-            if "OneCycleLR" in str(lr_scheduler):
-                lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
-            self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
-
-        # Extract the input shape for the torchsummary.
-        if isinstance(self._network_args["input_size"], int):
-            self._input_shape = (1,) + tuple([self._network_args["input_size"]])
-        else:
-            self._input_shape = (1,) + tuple(self._network_args["input_size"])
+        # Set training objects.
+        self._criterion = self._load_criterion(criterion, criterion_args)
+        self._optimizer = self._load_optimizer(optimizer, optimizer_args)
+        self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
 
         # Experiment directory.
         self.model_dir = None
@@ -117,6 +87,64 @@ class Model(ABC):
         # Flag for stopping training.
         self.stop_training = False
 
+    def _load_data_loader(
+        self, data_loader_args: Optional[Dict]
+    ) -> Tuple[str, Dict, EmnistMapper]:
+        """Loads data loader, dataset name, and dataset mapper."""
+        if data_loader_args is not None:
+            data_loaders = fetch_data_loaders(**data_loader_args)
+            dataset = list(data_loaders.values())[0].dataset
+            dataset_name = dataset.__name__
+            mapper = dataset.mapper
+        else:
+            self._mapper = EmnistMapper()
+            dataset_name = "*"
+            data_loaders = None
+        return dataset_name, data_loaders, mapper
+
+    def _load_network(
+        self, network_fn: Type[nn.Module], network_args: Optional[Dict]
+    ) -> Tuple[Type[nn.Module], Dict]:
+        """Loads the network."""
+        # If no network arguemnts are given, load pretrained weights if they exist.
+        if network_args is None:
+            network, network_args = self.load_weights(network_fn)
+        else:
+            network = network_fn(**network_args)
+        return network, network_args
+
+    def _load_criterion(
+        self, criterion: Optional[Callable], criterion_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads the criterion."""
+        if criterion is not None:
+            _criterion = criterion(**criterion_args)
+        else:
+            _criterion = None
+        return _criterion
+
+    def _load_optimizer(
+        self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads the optimizer."""
+        if optimizer is not None:
+            _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+        else:
+            _optimizer = None
+        return _optimizer
+
+    def _load_lr_scheduler(
+        self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads learning rate scheduler."""
+        if self._optimizer and lr_scheduler is not None:
+            if "OneCycleLR" in str(lr_scheduler):
+                lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
+            _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
+        else:
+            _lr_scheduler = None
+        return _lr_scheduler
+
     @property
     def __name__(self) -> str:
         """Returns the name of the model."""
@@ -127,10 +155,15 @@ class Model(ABC):
         """The input shape."""
         return self._input_shape
 
+    @property
+    def mapper(self) -> Dict:
+        """Returns the mapper that maps between ints and chars."""
+        return self._mapper
+
     @property
     def mapping(self) -> Dict:
-        """Returns the class mapping."""
-        return self._mapping
+        """Returns the mapping between network output and Emnist character."""
+        return self._mapper.mapping
 
     def eval(self) -> None:
         """Sets the network to evaluation mode."""
@@ -184,7 +217,11 @@ class Model(ABC):
     def summary(self) -> None:
         """Prints a summary of the network architecture."""
         device = re.sub("[^A-Za-z]+", "", self.device)
-        summary(self._network, self._input_shape, device=device)
+        if self._input_shape is not None:
+            input_shape = (1,) + tuple(self._input_shape)
+            summary(self._network, input_shape, device=device)
+        else:
+            logger.warning("Could not print summary as input shape is not set.")
 
     def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
@@ -218,8 +255,9 @@ class Model(ABC):
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])
 
-        if self._lr_scheduler is not None:
-            self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+        # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs.
+        # if self._lr_scheduler is not None:
+        #     self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
 
         epoch = checkpoint["epoch"]
 
@@ -257,7 +295,7 @@ class Model(ABC):
             )
             shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
 
-    def load_weights(self, network_fn: Type[nn.Module]) -> None:
+    def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
         """Load the network weights."""
         logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
@@ -267,12 +305,13 @@ class Model(ABC):
             )
         # Loading state directory.
         state_dict = torch.load(filename, map_location=torch.device(self._device))
-        self._network_args = state_dict["network_args"]
+        network_args = state_dict["network_args"]
         weights = state_dict["model_state"]
 
         # Initializes the network with trained weights.
-        self._network = network_fn(**self._network_args)
-        self._network.load_state_dict(weights)
+        network = network_fn(**self._network_args)
+        network.load_state_dict(weights)
+        return network, network_args
 
     def save_weights(self, path: Path) -> None:
         """Save the network weights."""
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index f1dabb7..0a0ab2d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -6,10 +6,6 @@ import torch
 from torch import nn
 from torchvision.transforms import ToTensor
 
-from text_recognizer.datasets.emnist_dataset import (
-    _augment_emnist_mapping,
-    _load_emnist_essentials,
-)
 from text_recognizer.models.base import Model
 
 
@@ -20,7 +16,6 @@ class CharacterModel(Model):
         self,
         network_fn: Type[nn.Module],
         network_args: Optional[Dict] = None,
-        data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
@@ -36,7 +31,6 @@ class CharacterModel(Model):
         super().__init__(
             network_fn,
             network_args,
-            data_loader,
             data_loader_args,
             metrics,
             criterion,
@@ -47,16 +41,9 @@ class CharacterModel(Model):
             lr_scheduler_args,
             device,
         )
-        if self.mapping is None:
-            self.load_mapping()
         self.tensor_transform = ToTensor()
         self.softmax = nn.Softmax(dim=0)
 
-    def load_mapping(self) -> None:
-        """Mapping between integers and classes."""
-        essentials = _load_emnist_essentials()
-        self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
-
     def predict_on_image(
         self, image: Union[np.ndarray, torch.Tensor]
     ) -> Tuple[str, float]:
@@ -86,6 +73,6 @@ class CharacterModel(Model):
 
         index = int(torch.argmax(prediction, dim=0))
         confidence_of_prediction = prediction[index]
-        predicted_character = self._mapping[index]
+        predicted_character = self.mapper(index)
 
         return predicted_character, confidence_of_prediction
-- 
cgit v1.2.3-70-g09d2