summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/base.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-03 23:33:34 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-03 23:33:34 +0200
commit07dd14116fe1d8148fb614b160245287533620fc (patch)
tree63395d88b17a14ad453c52889fcf541e6cbbdd3e /src/text_recognizer/models/base.py
parent704451318eb6b0b600ab314cb5aabfac82416bda (diff)
Working Emnist lines dataset.
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--src/text_recognizer/models/base.py84
1 files changed, 54 insertions, 30 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index b78eacb..84a86ca 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -22,7 +22,7 @@ class Model(ABC):
def __init__(
self,
network_fn: Type[nn.Module],
- network_args: Dict,
+ network_args: Optional[Dict] = None,
data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
@@ -38,7 +38,7 @@ class Model(ABC):
Args:
network_fn (Type[nn.Module]): The PyTorch network.
- network_args (Dict): Arguments for the network.
+ network_args (Optional[Dict]): Arguments for the network. Defaults to None.
data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
data_loader_args (Optional[Dict]): Arguments for the DataLoader.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
@@ -58,18 +58,14 @@ class Model(ABC):
if data_loader_args is not None:
self._data_loaders = data_loader(**data_loader_args)
dataset_name = self._data_loaders.__name__
+ self._mapping = self._data_loaders.mapping
else:
+ self._mapping = None
dataset_name = "*"
self._data_loaders = None
self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
- # Extract the input shape for the torchsummary.
- if isinstance(network_args["input_size"], int):
- self._input_shape = (1,) + tuple([network_args["input_size"]])
- else:
- self._input_shape = (1,) + tuple(network_args["input_size"])
-
if metrics is not None:
self._metrics = metrics
@@ -80,8 +76,13 @@ class Model(ABC):
self._device = device
# Load network.
- self.network_args = network_args
- self._network = network_fn(**self.network_args)
+ self._network = None
+ self._network_args = network_args
+ # If no network arguemnts are given, load pretrained weights if they exist.
+ if self._network_args is None:
+ self.load_weights(network_fn)
+ else:
+ self._network = network_fn(**self._network_args)
# To device.
self._network.to(self._device)
@@ -104,8 +105,17 @@ class Model(ABC):
lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
- # Class mapping.
- self._mapping = None
+ # Extract the input shape for the torchsummary.
+ if isinstance(self._network_args["input_size"], int):
+ self._input_shape = (1,) + tuple([self._network_args["input_size"]])
+ else:
+ self._input_shape = (1,) + tuple(self._network_args["input_size"])
+
+ # Experiment directory.
+ self.model_dir = None
+
+ # Flag for stopping training.
+ self.stop_training = False
@property
def __name__(self) -> str:
@@ -179,8 +189,13 @@ class Model(ABC):
def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
state = {"model_state": self._network.state_dict()}
+
if self._optimizer is not None:
state["optimizer_state"] = self._optimizer.state_dict()
+
+ if self._lr_scheduler is not None:
+ state["scheduler_state"] = self._lr_scheduler.state_dict()
+
return state
def load_checkpoint(self, path: Path) -> int:
@@ -203,54 +218,63 @@ class Model(ABC):
if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
+ if self._lr_scheduler is not None:
+ self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+
epoch = checkpoint["epoch"]
return epoch
- def save_checkpoint(
- self, path: Path, is_best: bool, epoch: int, val_metric: str
- ) -> None:
+ def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
"""Saves a checkpoint of the model.
Args:
- path (Path): Path to the experiment folder.
is_best (bool): If it is the currently best model.
epoch (int): The epoch of the checkpoint.
val_metric (str): Validation metric.
+ Raises:
+ ValueError: If the self.model_dir is not set.
+
"""
state = self._get_state_dict()
state["is_best"] = is_best
state["epoch"] = epoch
- state["network_args"] = self.network_args
+ state["network_args"] = self._network_args
- path.mkdir(parents=True, exist_ok=True)
+ if self.model_dir is None:
+ raise ValueError("Experiment directory is not set.")
+
+ self.model_dir.mkdir(parents=True, exist_ok=True)
logger.debug("Saving checkpoint...")
- filepath = str(path / "last.pt")
+ filepath = str(self.model_dir / "last.pt")
torch.save(state, filepath)
if is_best:
logger.debug(
f"Found a new best {val_metric}. Saving best checkpoint and weights."
)
- shutil.copyfile(filepath, str(path / "best.pt"))
+ shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
- def load_weights(self) -> None:
+ def load_weights(self, network_fn: Type[nn.Module]) -> None:
"""Load the network weights."""
- logger.debug("Loading network weights.")
+ logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
- weights = torch.load(filename, map_location=torch.device(self._device))[
- "model_state"
- ]
+ if not filename:
+ raise FileNotFoundError(
+ f"Could not find any pretrained weights at {self.weights_filename}"
+ )
+ # Loading state directory.
+ state_dict = torch.load(filename, map_location=torch.device(self._device))
+ self._network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ self._network = network_fn(**self._network_args)
self._network.load_state_dict(weights)
def save_weights(self, path: Path) -> None:
"""Save the network weights."""
logger.debug("Saving the best network weights.")
shutil.copyfile(str(path / "best.pt"), self.weights_filename)
-
- @abstractmethod
- def load_mapping(self) -> None:
- """Loads class mapping from network output to character."""
- ...