Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--  src/text_recognizer/models/base.py  45
1 file changed, 24 insertions(+), 21 deletions(-)
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 6d40b49..74fd223 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -53,8 +53,8 @@ class Model(ABC):
         """
-        # Fetch data loaders and dataset info.
-        dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+        # Configure data loaders and dataset info.
+        dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
             data_loader_args
         )
 
         self._input_shape = self._mapper.input_shape
 
@@ -70,16 +70,19 @@
         else:
             self._device = device
 
-        # Load network.
-        self._network, self._network_args = self._load_network(network_fn, network_args)
+        # Configure network.
+        self._network, self._network_args = self._configure_network(
+            network_fn, network_args
+        )
 
         # To device.
         self._network.to(self._device)
 
-        # Set training objects.
-        self._criterion = self._load_criterion(criterion, criterion_args)
-        self._optimizer = self._load_optimizer(optimizer, optimizer_args)
-        self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
+        # Configure training objects.
+        self._criterion = self._configure_criterion(criterion, criterion_args)
+        self._optimizer, self._lr_scheduler = self._configure_optimizers(
+            optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+        )
 
         # Experiment directory.
         self.model_dir = None
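
Aside on the hunk above: replacing two separate setter calls with a single _configure_optimizers call that returns an (optimizer, scheduler) pair guarantees the scheduler is built against the optimizer instance it will drive. A self-contained sketch of that shape, assuming only torch; the free-function form, names, and hyperparameters are illustrative, not this repo's API:

    from typing import Callable, Dict, Optional, Tuple

    import torch.optim as optim
    from torch import nn


    def configure_optimizers(
        network: nn.Module,
        optimizer: Optional[Callable],
        optimizer_args: Optional[Dict],
        lr_scheduler: Optional[Callable],
        lr_scheduler_args: Optional[Dict],
    ) -> Tuple[Optional[optim.Optimizer], Optional[object]]:
        # Build the optimizer first; the scheduler must wrap a live optimizer.
        _optimizer = (
            optimizer(network.parameters(), **optimizer_args) if optimizer else None
        )
        _lr_scheduler = (
            lr_scheduler(_optimizer, **lr_scheduler_args)
            if _optimizer is not None and lr_scheduler is not None
            else None
        )
        return _optimizer, _lr_scheduler


    # Usage: unpack both in one statement, mirroring the constructor above.
    net = nn.Linear(4, 2)
    opt, sched = configure_optimizers(
        net, optim.Adam, {"lr": 1e-3}, optim.lr_scheduler.StepLR, {"step_size": 1}
    )
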
@@ -87,7 +90,7 @@
         # Flag for stopping training.
         self.stop_training = False
 
-    def _load_data_loader(
+    def _configure_data_loader(
         self, data_loader_args: Optional[Dict]
     ) -> Tuple[str, Dict, EmnistMapper]:
         """Loads data loader, dataset name, and dataset mapper."""
@@ -102,7 +105,7 @@
             data_loaders = None
         return dataset_name, data_loaders, mapper
 
-    def _load_network(
+    def _configure_network(
         self, network_fn: Type[nn.Module], network_args: Optional[Dict]
     ) -> Tuple[Type[nn.Module], Dict]:
         """Loads the network."""
@@ -113,7 +116,7 @@
         network = network_fn(**network_args)
         return network, network_args
 
-    def _load_criterion(
+    def _configure_criterion(
         self, criterion: Optional[Callable], criterion_args: Optional[Dict]
     ) -> Optional[Callable]:
         """Loads the criterion."""
@@ -123,27 +126,27 @@
             _criterion = None
         return _criterion
 
-    def _load_optimizer(
-        self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
-    ) -> Optional[Callable]:
-        """Loads the optimizer."""
+    def _configure_optimizers(
+        self,
+        optimizer: Optional[Callable],
+        optimizer_args: Optional[Dict],
+        lr_scheduler: Optional[Callable],
+        lr_scheduler_args: Optional[Dict],
+    ) -> Tuple[Optional[Callable], Optional[Callable]]:
+        """Loads the optimizers."""
         if optimizer is not None:
             _optimizer = optimizer(self._network.parameters(), **optimizer_args)
         else:
             _optimizer = None
-        return _optimizer
 
-    def _load_lr_scheduler(
-        self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
-    ) -> Optional[Callable]:
-        """Loads learning rate scheduler."""
         if self._optimizer and lr_scheduler is not None:
             if "OneCycleLR" in str(lr_scheduler):
                 lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
             _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
         else:
             _lr_scheduler = None
-        return _lr_scheduler
+
+        return _optimizer, _lr_scheduler
 
     @property
     def __name__(self) -> str:
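
One note on the merged method: the OneCycleLR branch injects steps_per_epoch because PyTorch's torch.optim.lr_scheduler.OneCycleLR must know the total step count up front (either total_steps, or epochs together with steps_per_epoch), and the length of the train DataLoader is exactly the number of steps per epoch. A runnable sketch of that branch in isolation, assuming torch; the toy dataset and hyperparameters are illustrative:

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    train_loader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=16)
    network = nn.Linear(8, 2)
    optimizer = optim.SGD(network.parameters(), lr=0.1)

    lr_scheduler_args = {"max_lr": 0.1, "epochs": 5}
    # Same move as the diff: derive steps_per_epoch from the train loader,
    # so callers never have to pass it explicitly.
    lr_scheduler_args["steps_per_epoch"] = len(train_loader)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, **lr_scheduler_args)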