From 07dd14116fe1d8148fb614b160245287533620fc Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Mon, 3 Aug 2020 23:33:34 +0200
Subject: Working Emnist lines dataset.

---
 src/text_recognizer/models/base.py            | 84 +++++++++++++++++----------
 src/text_recognizer/models/character_model.py | 32 ++++++----
 2 files changed, 75 insertions(+), 41 deletions(-)

(limited to 'src/text_recognizer/models')

diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index b78eacb..84a86ca 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -22,7 +22,7 @@ class Model(ABC):
     def __init__(
         self,
         network_fn: Type[nn.Module],
-        network_args: Dict,
+        network_args: Optional[Dict] = None,
         data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
@@ -38,7 +38,7 @@ class Model(ABC):
 
         Args:
             network_fn (Type[nn.Module]): The PyTorch network.
-            network_args (Dict): Arguments for the network.
+            network_args (Optional[Dict]): Arguments for the network. Defaults to None.
             data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
             data_loader_args (Optional[Dict]):  Arguments for the DataLoader.
             metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
@@ -58,18 +58,14 @@ class Model(ABC):
         if data_loader_args is not None:
             self._data_loaders = data_loader(**data_loader_args)
             dataset_name = self._data_loaders.__name__
+            self._mapping = self._data_loaders.mapping
         else:
+            self._mapping = None
             dataset_name = "*"
             self._data_loaders = None
 
         self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
 
-        # Extract the input shape for the torchsummary.
-        if isinstance(network_args["input_size"], int):
-            self._input_shape = (1,) + tuple([network_args["input_size"]])
-        else:
-            self._input_shape = (1,) + tuple(network_args["input_size"])
-
         if metrics is not None:
             self._metrics = metrics
 
@@ -80,8 +76,13 @@ class Model(ABC):
             self._device = device
 
         # Load network.
-        self.network_args = network_args
-        self._network = network_fn(**self.network_args)
+        self._network = None
+        self._network_args = network_args
+        # If no network arguemnts are given, load pretrained weights if they exist.
+        if self._network_args is None:
+            self.load_weights(network_fn)
+        else:
+            self._network = network_fn(**self._network_args)
 
         # To device.
         self._network.to(self._device)
@@ -104,8 +105,17 @@ class Model(ABC):
                 lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
             self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
 
-        # Class mapping.
-        self._mapping = None
+        # Extract the input shape for the torchsummary.
+        if isinstance(self._network_args["input_size"], int):
+            self._input_shape = (1,) + tuple([self._network_args["input_size"]])
+        else:
+            self._input_shape = (1,) + tuple(self._network_args["input_size"])
+
+        # Experiment directory.
+        self.model_dir = None
+
+        # Flag for stopping training.
+        self.stop_training = False
 
     @property
     def __name__(self) -> str:
@@ -179,8 +189,13 @@ class Model(ABC):
     def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
         state = {"model_state": self._network.state_dict()}
+
         if self._optimizer is not None:
             state["optimizer_state"] = self._optimizer.state_dict()
+
+        if self._lr_scheduler is not None:
+            state["scheduler_state"] = self._lr_scheduler.state_dict()
+
         return state
 
     def load_checkpoint(self, path: Path) -> int:
@@ -203,54 +218,63 @@ class Model(ABC):
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])
 
+        if self._lr_scheduler is not None:
+            self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+
         epoch = checkpoint["epoch"]
 
         return epoch
 
-    def save_checkpoint(
-        self, path: Path, is_best: bool, epoch: int, val_metric: str
-    ) -> None:
+    def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
         """Saves a checkpoint of the model.
 
         Args:
-            path (Path): Path to the experiment folder.
             is_best (bool): If it is the currently best model.
             epoch (int): The epoch of the checkpoint.
             val_metric (str): Validation metric.
 
+        Raises:
+            ValueError: If the self.model_dir is not set.
+
         """
         state = self._get_state_dict()
         state["is_best"] = is_best
         state["epoch"] = epoch
-        state["network_args"] = self.network_args
+        state["network_args"] = self._network_args
 
-        path.mkdir(parents=True, exist_ok=True)
+        if self.model_dir is None:
+            raise ValueError("Experiment directory is not set.")
+
+        self.model_dir.mkdir(parents=True, exist_ok=True)
 
         logger.debug("Saving checkpoint...")
-        filepath = str(path / "last.pt")
+        filepath = str(self.model_dir / "last.pt")
         torch.save(state, filepath)
 
         if is_best:
             logger.debug(
                 f"Found a new best {val_metric}. Saving best checkpoint and weights."
             )
-            shutil.copyfile(filepath, str(path / "best.pt"))
+            shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
 
-    def load_weights(self) -> None:
+    def load_weights(self, network_fn: Type[nn.Module]) -> None:
         """Load the network weights."""
-        logger.debug("Loading network weights.")
+        logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
-        weights = torch.load(filename, map_location=torch.device(self._device))[
-            "model_state"
-        ]
+        if not filename:
+            raise FileNotFoundError(
+                f"Could not find any pretrained weights at {self.weights_filename}"
+            )
+        # Loading state directory.
+        state_dict = torch.load(filename, map_location=torch.device(self._device))
+        self._network_args = state_dict["network_args"]
+        weights = state_dict["model_state"]
+
+        # Initializes the network with trained weights.
+        self._network = network_fn(**self._network_args)
         self._network.load_state_dict(weights)
 
     def save_weights(self, path: Path) -> None:
         """Save the network weights."""
         logger.debug("Saving the best network weights.")
         shutil.copyfile(str(path / "best.pt"), self.weights_filename)
-
-    @abstractmethod
-    def load_mapping(self) -> None:
-        """Loads class mapping from network output to character."""
-        ...
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 527fc7d..f1dabb7 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,12 +1,15 @@
 """Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
 from torch import nn
 from torchvision.transforms import ToTensor
 
-from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
+from text_recognizer.datasets.emnist_dataset import (
+    _augment_emnist_mapping,
+    _load_emnist_essentials,
+)
 from text_recognizer.models.base import Model
 
 
@@ -16,7 +19,7 @@ class CharacterModel(Model):
     def __init__(
         self,
         network_fn: Type[nn.Module],
-        network_args: Dict,
+        network_args: Optional[Dict] = None,
         data_loader: Optional[Callable] = None,
         data_loader_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
@@ -44,19 +47,23 @@ class CharacterModel(Model):
             lr_scheduler_args,
             device,
         )
-        self.load_mapping()
+        if self.mapping is None:
+            self.load_mapping()
         self.tensor_transform = ToTensor()
         self.softmax = nn.Softmax(dim=0)
 
     def load_mapping(self) -> None:
         """Mapping between integers and classes."""
-        self._mapping = load_emnist_mapping()
+        essentials = _load_emnist_essentials()
+        self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
 
-    def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
+    def predict_on_image(
+        self, image: Union[np.ndarray, torch.Tensor]
+    ) -> Tuple[str, float]:
         """Character prediction on an image.
 
         Args:
-            image (np.ndarray): An image containing a character.
+            image (Union[np.ndarray, torch.Tensor]): An image containing a character.
 
         Returns:
             Tuple[str, float]: The predicted character and the confidence in the prediction.
@@ -64,12 +71,15 @@ class CharacterModel(Model):
         """
 
         if image.dtype == np.uint8:
-            image = (image / 255).astype(np.float32)
-
-        # Conver to Pytorch Tensor.
-        image = self.tensor_transform(image)
+            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
 
         with torch.no_grad():
+            # Put the image tensor on the device the model weights are on.
+            image = image.to(self.device)
             logits = self.network(image)
 
         prediction = self.softmax(logits.data.squeeze())
-- 
cgit v1.2.3-70-g09d2