summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--src/text_recognizer/models/base.py66
1 files changed, 46 insertions, 20 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 0cc531a..b78eacb 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -1,9 +1,11 @@
"""Abstract Model class for PyTorch neural networks."""
from abc import ABC, abstractmethod
+from glob import glob
from pathlib import Path
+import re
import shutil
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
from loguru import logger
import torch
@@ -19,7 +21,7 @@ class Model(ABC):
def __init__(
self,
- network_fn: Callable,
+ network_fn: Type[nn.Module],
network_args: Dict,
data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
@@ -35,7 +37,7 @@ class Model(ABC):
"""Base class, to be inherited by model for specific type of data.
Args:
- network_fn (Callable): The PyTorch network.
+ network_fn (Type[nn.Module]): The PyTorch network.
network_args (Dict): Arguments for the network.
data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
data_loader_args (Optional[Dict]): Arguments for the DataLoader.
@@ -57,27 +59,29 @@ class Model(ABC):
self._data_loaders = data_loader(**data_loader_args)
dataset_name = self._data_loaders.__name__
else:
- dataset_name = ""
+ dataset_name = "*"
self._data_loaders = None
- self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+ self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
# Extract the input shape for the torchsummary.
- self._input_shape = network_args.pop("input_shape")
+ if isinstance(network_args["input_size"], int):
+ self._input_shape = (1,) + tuple([network_args["input_size"]])
+ else:
+ self._input_shape = (1,) + tuple(network_args["input_size"])
if metrics is not None:
self._metrics = metrics
# Set the device.
- if self.device is None:
- self._device = torch.device(
- "cuda:0" if torch.cuda.is_available() else "cpu"
- )
+ if device is None:
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self._device = device
# Load network.
- self._network = network_fn(**network_args)
+ self.network_args = network_args
+ self._network = network_fn(**self.network_args)
# To device.
self._network.to(self._device)
@@ -95,13 +99,29 @@ class Model(ABC):
# Set learning rate scheduler.
self._lr_scheduler = None
if lr_scheduler is not None:
+ # OneCycleLR needs the number of steps in an epoch as an input argument.
+ if "OneCycleLR" in str(lr_scheduler):
+ lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
+ # Class mapping.
+ self._mapping = None
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the model."""
+ return self._name
+
@property
def input_shape(self) -> Tuple[int, ...]:
"""The input shape."""
return self._input_shape
+ @property
+ def mapping(self) -> Dict:
+ """Returns the class mapping."""
+ return self._mapping
+
def eval(self) -> None:
"""Sets the network to evaluation mode."""
self._network.eval()
@@ -149,13 +169,14 @@ class Model(ABC):
def weights_filename(self) -> str:
"""Filepath to the network weights."""
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
- return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt")
+ return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
def summary(self) -> None:
"""Prints a summary of the network architecture."""
- summary(self._network, self._input_shape, device=self.device)
+ device = re.sub("[^A-Za-z]+", "", self.device)
+ summary(self._network, self._input_shape, device=device)
- def _get_state(self) -> Dict:
+ def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
state = {"model_state": self._network.state_dict()}
if self._optimizer is not None:
@@ -172,6 +193,7 @@ class Model(ABC):
epoch (int): The last epoch when the checkpoint was created.
"""
+ logger.debug("Loading checkpoint...")
if not path.exists():
logger.debug("File does not exist {str(path)}")
@@ -200,6 +222,7 @@ class Model(ABC):
state = self._get_state_dict()
state["is_best"] = is_best
state["epoch"] = epoch
+ state["network_args"] = self.network_args
path.mkdir(parents=True, exist_ok=True)
@@ -216,15 +239,18 @@ class Model(ABC):
def load_weights(self) -> None:
"""Load the network weights."""
logger.debug("Loading network weights.")
- weights = torch.load(self.weights_filename)["model_state"]
+ filename = glob(self.weights_filename)[0]
+ weights = torch.load(filename, map_location=torch.device(self._device))[
+ "model_state"
+ ]
self._network.load_state_dict(weights)
- def save_weights(self) -> None:
+ def save_weights(self, path: Path) -> None:
"""Save the network weights."""
- logger.debug("Saving network weights.")
- torch.save({"model_state": self._network.state_dict()}, self.weights_filename)
+ logger.debug("Saving the best network weights.")
+ shutil.copyfile(str(path / "best.pt"), self.weights_filename)
@abstractmethod
- def mapping(self) -> Dict:
- """Mapping from network output to class."""
+ def load_mapping(self) -> None:
+ """Loads class mapping from network output to character."""
...