summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/base.py45
-rw-r--r--src/text_recognizer/models/character_model.py8
2 files changed, 28 insertions, 25 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 6d40b49..74fd223 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -53,8 +53,8 @@ class Model(ABC):
"""
- # Fetch data loaders and dataset info.
- dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+ # Configure data loaders and dataset info.
+ dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
data_loader_args
)
self._input_shape = self._mapper.input_shape
@@ -70,16 +70,19 @@ class Model(ABC):
else:
self._device = device
- # Load network.
- self._network, self._network_args = self._load_network(network_fn, network_args)
+ # Configure network.
+ self._network, self._network_args = self._configure_network(
+ network_fn, network_args
+ )
# To device.
self._network.to(self._device)
- # Set training objects.
- self._criterion = self._load_criterion(criterion, criterion_args)
- self._optimizer = self._load_optimizer(optimizer, optimizer_args)
- self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
+ # Configure training objects.
+ self._criterion = self._configure_criterion(criterion, criterion_args)
+ self._optimizer, self._lr_scheduler = self._configure_optimizers(
+ optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+ )
# Experiment directory.
self.model_dir = None
@@ -87,7 +90,7 @@ class Model(ABC):
# Flag for stopping training.
self.stop_training = False
- def _load_data_loader(
+ def _configure_data_loader(
self, data_loader_args: Optional[Dict]
) -> Tuple[str, Dict, EmnistMapper]:
"""Loads data loader, dataset name, and dataset mapper."""
@@ -102,7 +105,7 @@ class Model(ABC):
data_loaders = None
return dataset_name, data_loaders, mapper
- def _load_network(
+ def _configure_network(
self, network_fn: Type[nn.Module], network_args: Optional[Dict]
) -> Tuple[Type[nn.Module], Dict]:
"""Loads the network."""
@@ -113,7 +116,7 @@ class Model(ABC):
network = network_fn(**network_args)
return network, network_args
- def _load_criterion(
+ def _configure_criterion(
self, criterion: Optional[Callable], criterion_args: Optional[Dict]
) -> Optional[Callable]:
"""Loads the criterion."""
@@ -123,27 +126,27 @@ class Model(ABC):
_criterion = None
return _criterion
- def _load_optimizer(
- self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
- ) -> Optional[Callable]:
- """Loads the optimizer."""
+ def _configure_optimizers(
+ self,
+ optimizer: Optional[Callable],
+ optimizer_args: Optional[Dict],
+ lr_scheduler: Optional[Callable],
+ lr_scheduler_args: Optional[Dict],
+ ) -> Tuple[Optional[Callable], Optional[Callable]]:
+ """Loads the optimizers."""
if optimizer is not None:
_optimizer = optimizer(self._network.parameters(), **optimizer_args)
else:
_optimizer = None
- return _optimizer
- def _load_lr_scheduler(
- self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
- ) -> Optional[Callable]:
- """Loads learning rate scheduler."""
if self._optimizer and lr_scheduler is not None:
if "OneCycleLR" in str(lr_scheduler):
lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
_lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
else:
_lr_scheduler = None
- return _lr_scheduler
+
+ return _optimizer, _lr_scheduler
@property
def __name__(self) -> str:
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 0a0ab2d..0fd7afd 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -44,6 +44,7 @@ class CharacterModel(Model):
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
+ @torch.no_grad()
def predict_on_image(
self, image: Union[np.ndarray, torch.Tensor]
) -> Tuple[str, float]:
@@ -64,10 +65,9 @@ class CharacterModel(Model):
# If the image is an unscaled tensor.
image = image.type("torch.FloatTensor") / 255
- with torch.no_grad():
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- logits = self.network(image)
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ logits = self.network(image)
prediction = self.softmax(logits.data.squeeze())