summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r-- src/text_recognizer/models/base.py | 139
1 files changed, 89 insertions, 50 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 84a86ca..6d40b49 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -12,6 +12,7 @@ import torch
from torch import nn
from torchsummary import summary
+from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -23,7 +24,6 @@ class Model(ABC):
self,
network_fn: Type[nn.Module],
network_args: Optional[Dict] = None,
- data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -39,7 +39,6 @@ class Model(ABC):
Args:
network_fn (Type[nn.Module]): The PyTorch network.
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
- data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
data_loader_args (Optional[Dict]): Arguments for the DataLoader.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
@@ -54,15 +53,11 @@ class Model(ABC):
"""
- # Fetch data loaders.
- if data_loader_args is not None:
- self._data_loaders = data_loader(**data_loader_args)
- dataset_name = self._data_loaders.__name__
- self._mapping = self._data_loaders.mapping
- else:
- self._mapping = None
- dataset_name = "*"
- self._data_loaders = None
+ # Fetch data loaders and dataset info.
+ dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+ data_loader_args
+ )
+ self._input_shape = self._mapper.input_shape
self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
@@ -76,40 +71,15 @@ class Model(ABC):
self._device = device
# Load network.
- self._network = None
- self._network_args = network_args
- # If no network arguemnts are given, load pretrained weights if they exist.
- if self._network_args is None:
- self.load_weights(network_fn)
- else:
- self._network = network_fn(**self._network_args)
+ self._network, self._network_args = self._load_network(network_fn, network_args)
# To device.
self._network.to(self._device)
- # Set criterion.
- self._criterion = None
- if criterion is not None:
- self._criterion = criterion(**criterion_args)
-
- # Set optimizer.
- self._optimizer = None
- if optimizer is not None:
- self._optimizer = optimizer(self._network.parameters(), **optimizer_args)
-
- # Set learning rate scheduler.
- self._lr_scheduler = None
- if lr_scheduler is not None:
- # OneCycleLR needs the number of steps in an epoch as an input argument.
- if "OneCycleLR" in str(lr_scheduler):
- lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
- self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
-
- # Extract the input shape for the torchsummary.
- if isinstance(self._network_args["input_size"], int):
- self._input_shape = (1,) + tuple([self._network_args["input_size"]])
- else:
- self._input_shape = (1,) + tuple(self._network_args["input_size"])
+ # Set training objects.
+ self._criterion = self._load_criterion(criterion, criterion_args)
+ self._optimizer = self._load_optimizer(optimizer, optimizer_args)
+ self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
# Experiment directory.
self.model_dir = None
@@ -117,6 +87,64 @@ class Model(ABC):
# Flag for stopping training.
self.stop_training = False
+    def _load_data_loader(
+        self, data_loader_args: Optional[Dict]
+    ) -> Tuple[str, Optional[Dict], EmnistMapper]:
+        """Loads the data loaders, the dataset name, and the dataset mapper.
+
+        Args:
+            data_loader_args (Optional[Dict]): Keyword arguments forwarded to
+                fetch_data_loaders. If None, no data loaders are created.
+
+        Returns:
+            Tuple[str, Optional[Dict], EmnistMapper]: The dataset name ("*" when no
+                data loaders are requested), the data loaders (None when no data
+                loaders are requested), and the dataset mapper.
+
+        """
+        if data_loader_args is None:
+            # No dataset to take the mapper from, so a default EmnistMapper is
+            # returned. It is returned (not assigned to self) because the caller
+            # unpacks the return value into self._mapper.
+            return "*", None, EmnistMapper()
+
+        data_loaders = fetch_data_loaders(**data_loader_args)
+        # The name and mapper are read from the dataset of the first loader.
+        # NOTE(review): assumes every loader wraps the same kind of dataset - confirm.
+        dataset = list(data_loaders.values())[0].dataset
+        return dataset.__name__, data_loaders, dataset.mapper
+
+    def _load_network(
+        self, network_fn: Type[nn.Module], network_args: Optional[Dict]
+    ) -> Tuple[nn.Module, Dict]:
+        """Creates the network from arguments or from previously saved weights.
+
+        Args:
+            network_fn (Type[nn.Module]): The PyTorch network class.
+            network_args (Optional[Dict]): Arguments for the network. If None, the
+                network and its arguments are obtained from load_weights.
+
+        Returns:
+            Tuple[nn.Module, Dict]: The network instance and the arguments it was
+                created with.
+
+        """
+        if network_args is not None:
+            return network_fn(**network_args), network_args
+        # If no network arguments are given, fall back to the saved weights, which
+        # also carry the arguments the network was created with.
+        return self.load_weights(network_fn)
+
+    def _load_criterion(
+        self, criterion: Optional[Callable], criterion_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads the criterion.
+
+        Args:
+            criterion (Optional[Callable]): Factory for the criterion. If None, no
+                criterion is created.
+            criterion_args (Optional[Dict]): Keyword arguments for the criterion.
+                None is treated as no arguments.
+
+        Returns:
+            Optional[Callable]: The instantiated criterion, or None.
+
+        """
+        if criterion is None:
+            return None
+        # The arguments are optional; unpacking None would raise a TypeError.
+        return criterion(**(criterion_args or {}))
+
+    def _load_optimizer(
+        self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads the optimizer for the parameters of the network.
+
+        Must be called after self._network has been set, since the network
+        parameters are handed to the optimizer.
+
+        Args:
+            optimizer (Optional[Callable]): Factory for the optimizer. If None, no
+                optimizer is created.
+            optimizer_args (Optional[Dict]): Keyword arguments for the optimizer.
+                None is treated as no arguments.
+
+        Returns:
+            Optional[Callable]: The instantiated optimizer, or None.
+
+        """
+        if optimizer is None:
+            return None
+        # The arguments are optional; unpacking None would raise a TypeError.
+        return optimizer(self._network.parameters(), **(optimizer_args or {}))
+
+    def _load_lr_scheduler(
+        self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
+    ) -> Optional[Callable]:
+        """Loads the learning rate scheduler.
+
+        Must be called after self._optimizer and self._data_loaders have been set.
+
+        Args:
+            lr_scheduler (Optional[Callable]): Factory for the scheduler. If None, no
+                scheduler is created.
+            lr_scheduler_args (Optional[Dict]): Keyword arguments for the scheduler.
+                None is treated as no arguments. The dictionary is not modified.
+
+        Returns:
+            Optional[Callable]: The instantiated scheduler, or None if either the
+                optimizer or the scheduler factory is missing.
+
+        """
+        if self._optimizer is None or lr_scheduler is None:
+            return None
+
+        # Work on a copy so the caller's dictionary is left untouched.
+        scheduler_args = dict(lr_scheduler_args or {})
+
+        # OneCycleLR needs the number of steps in an epoch as an input argument.
+        # Without data loaders the value cannot be derived here, so the arguments
+        # are passed through as given.
+        if "OneCycleLR" in str(lr_scheduler) and self._data_loaders is not None:
+            scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
+
+        return lr_scheduler(self._optimizer, **scheduler_args)
+
@property
def __name__(self) -> str:
"""Returns the name of the model."""
@@ -128,9 +156,14 @@ class Model(ABC):
return self._input_shape
@property
+    def mapper(self) -> EmnistMapper:
+        """Returns the mapper that maps between ints and chars.
+
+        This is the mapper object itself; the `mapping` property exposes its
+        `mapping` attribute.
+        """
+        return self._mapper
+
+ @property
def mapping(self) -> Dict:
- """Returns the class mapping."""
- return self._mapping
+ """Returns the mapping between network output and Emnist character."""
+ return self._mapper.mapping
def eval(self) -> None:
"""Sets the network to evaluation mode."""
@@ -184,7 +217,11 @@ class Model(ABC):
def summary(self) -> None:
"""Prints a summary of the network architecture."""
device = re.sub("[^A-Za-z]+", "", self.device)
- summary(self._network, self._input_shape, device=device)
+ if self._input_shape is not None:
+ input_shape = (1,) + tuple(self._input_shape)
+ summary(self._network, input_shape, device=device)
+ else:
+ logger.warning("Could not print summary as input shape is not set.")
def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
@@ -218,8 +255,9 @@ class Model(ABC):
if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
- if self._lr_scheduler is not None:
- self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+ # Restoring the scheduler state is disabled: it does not work when loading from a previous checkpoint and trying to train beyond the last max epochs.
+ # if self._lr_scheduler is not None:
+ # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
epoch = checkpoint["epoch"]
@@ -257,7 +295,7 @@ class Model(ABC):
)
shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
- def load_weights(self, network_fn: Type[nn.Module]) -> None:
+ def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[nn.Module, Dict]:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -267,12 +305,13 @@ class Model(ABC):
)
# Loading state directory.
state_dict = torch.load(filename, map_location=torch.device(self._device))
- self._network_args = state_dict["network_args"]
+ network_args = state_dict["network_args"]
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- self._network = network_fn(**self._network_args)
- self._network.load_state_dict(weights)
+ # Build the network from the arguments stored with the weights. The local
+ # network_args must be used here: self._network_args is only assigned after
+ # this method returns when it is called from the constructor.
+ network = network_fn(**network_args)
+ network.load_state_dict(weights)
+ # Return both so the caller can store the network and its arguments.
+ return network, network_args
def save_weights(self, path: Path) -> None:
"""Save the network weights."""