diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
commit | 1f459ba19422593de325983040e176f97cf4ffc0 (patch) | |
tree | 89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/text_recognizer/models | |
parent | 95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff) |
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/base.py | 45 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 8 |
2 files changed, 28 insertions, 25 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 6d40b49..74fd223 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -53,8 +53,8 @@ class Model(ABC): """ - # Fetch data loaders and dataset info. - dataset_name, self._data_loaders, self._mapper = self._load_data_loader( + # Configure data loaders and dataset info. + dataset_name, self._data_loaders, self._mapper = self._configure_data_loader( data_loader_args ) self._input_shape = self._mapper.input_shape @@ -70,16 +70,19 @@ class Model(ABC): else: self._device = device - # Load network. - self._network, self._network_args = self._load_network(network_fn, network_args) + # Configure network. + self._network, self._network_args = self._configure_network( + network_fn, network_args + ) # To device. self._network.to(self._device) - # Set training objects. - self._criterion = self._load_criterion(criterion, criterion_args) - self._optimizer = self._load_optimizer(optimizer, optimizer_args) - self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args) + # Configure training objects. + self._criterion = self._configure_criterion(criterion, criterion_args) + self._optimizer, self._lr_scheduler = self._configure_optimizers( + optimizer, optimizer_args, lr_scheduler, lr_scheduler_args + ) # Experiment directory. self.model_dir = None @@ -87,7 +90,7 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False - def _load_data_loader( + def _configure_data_loader( self, data_loader_args: Optional[Dict] ) -> Tuple[str, Dict, EmnistMapper]: """Loads data loader, dataset name, and dataset mapper.""" @@ -102,7 +105,7 @@ class Model(ABC): data_loaders = None return dataset_name, data_loaders, mapper - def _load_network( + def _configure_network( self, network_fn: Type[nn.Module], network_args: Optional[Dict] ) -> Tuple[Type[nn.Module], Dict]: """Loads the network.""" @@ -113,7 +116,7 @@ class Model(ABC): network = network_fn(**network_args) return network, network_args - def _load_criterion( + def _configure_criterion( self, criterion: Optional[Callable], criterion_args: Optional[Dict] ) -> Optional[Callable]: """Loads the criterion.""" @@ -123,27 +126,27 @@ class Model(ABC): _criterion = None return _criterion - def _load_optimizer( - self, optimizer: Optional[Callable], optimizer_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads the optimizer.""" + def _configure_optimizers( + self, + optimizer: Optional[Callable], + optimizer_args: Optional[Dict], + lr_scheduler: Optional[Callable], + lr_scheduler_args: Optional[Dict], + ) -> Tuple[Optional[Callable], Optional[Callable]]: + """Loads the optimizers.""" if optimizer is not None: _optimizer = optimizer(self._network.parameters(), **optimizer_args) else: _optimizer = None - return _optimizer - def _load_lr_scheduler( - self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads learning rate scheduler.""" if self._optimizer and lr_scheduler is not None: if "OneCycleLR" in str(lr_scheduler): lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) else: _lr_scheduler = None - return _lr_scheduler + + return _optimizer, _lr_scheduler @property def __name__(self) -> str: diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0a0ab2d..0fd7afd 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -44,6 +44,7 @@ class CharacterModel(Model): self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) + @torch.no_grad() def predict_on_image( self, image: Union[np.ndarray, torch.Tensor] ) -> Tuple[str, float]: @@ -64,10 +65,9 @@ class CharacterModel(Model): # If the image is an unscaled tensor. image = image.type("torch.FloatTensor") / 255 - with torch.no_grad(): - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - logits = self.network(image) + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + logits = self.network(image) prediction = self.softmax(logits.data.squeeze()) |