summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 00:14:27 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 00:14:27 +0200
commite181195a699d7fa237f256d90ab4dedffc03d405 (patch)
tree6d8d50731a7267c56f7bf3ed5ecec3990c0e55a5 /src/text_recognizer/models
parent3b06ef615a8db67a03927576e0c12fbfb2501f5f (diff)
Minor bug fixes etc.
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/base.py56
-rw-r--r--src/text_recognizer/models/character_model.py6
-rw-r--r--src/text_recognizer/models/line_ctc_model.py8
3 files changed, 39 insertions, 31 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d23fe56..caf8065 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -77,9 +77,9 @@ class Model(ABC):
# Stochastic Weight Averaging placeholders.
self.swa_args = swa_args
- self._swa_start = None
self._swa_scheduler = None
self._swa_network = None
+ self._use_swa_model = False
# Experiment directory.
self.model_dir = None
@@ -220,15 +220,24 @@ class Model(ABC):
if self._optimizer and self._lr_scheduler is not None:
if "OneCycleLR" in str(self._lr_scheduler):
self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
- self._lr_scheduler = self._lr_scheduler(
- self._optimizer, **self.lr_scheduler_args
- )
- else:
- self._lr_scheduler = None
+
+ # Assume lr scheduler should update at each epoch if not specified.
+ if "interval" not in self.lr_scheduler_args:
+ interval = "epoch"
+ else:
+ interval = self.lr_scheduler_args.pop("interval")
+ self._lr_scheduler = {
+ "lr_scheduler": self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ ),
+ "interval": interval,
+ }
if self.swa_args is not None:
- self._swa_start = self.swa_args["start"]
- self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+ self._swa_scheduler = {
+ "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
+ "swa_start": self.swa_args["start"],
+ }
self._swa_network = AveragedModel(self._network).to(self.device)
@property
@@ -280,21 +289,16 @@ class Model(ABC):
return self._optimizer
@property
- def lr_scheduler(self) -> Optional[Callable]:
- """Learning rate scheduler."""
+ def lr_scheduler(self) -> Optional[Dict]:
+ """Returns a dictionary with the learning rate scheduler."""
return self._lr_scheduler
@property
- def swa_scheduler(self) -> Optional[Callable]:
- """Returns the stochastic weight averaging scheduler."""
+ def swa_scheduler(self) -> Optional[Dict]:
+ """Returns a dictionary with the stochastic weight averaging scheduler."""
return self._swa_scheduler
@property
- def swa_start(self) -> Optional[Callable]:
- """Returns the start epoch of stochastic weight averaging."""
- return self._swa_start
-
- @property
def swa_network(self) -> Optional[Callable]:
"""Returns the stochastic weight averaging network."""
return self._swa_network
@@ -311,20 +315,32 @@ class Model(ABC):
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
+ def use_swa_model(self) -> None:
+ """Set to use predictions from SWA model."""
+ if self.swa_network is not None:
+ self._use_swa_model = True
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass with the network."""
+ if self._use_swa_model:
+ return self.swa_network(x)
+ else:
+ return self.network(x)
+
def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
"""Compute the loss."""
return self.criterion(output, targets)
def summary(
- self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+ self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
) -> None:
"""Prints a summary of the network architecture."""
if input_shape is not None:
- summary(self._network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=self.device)
elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self._network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=self.device)
else:
logger.warning("Could not print summary as input shape is not set.")
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 64ba693..50e94a2 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -75,11 +75,7 @@ class CharacterModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- logits = (
- self.swa_network(image)
- if self.swa_network is not None
- else self.network(image)
- )
+ logits = self.forward(image)
prediction = self.softmax(logits.squeeze(0))
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index af41f18..16eaed3 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -98,16 +98,12 @@ class LineCTCModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- log_probs = (
- self.swa_network(image)
- if self.swa_network is not None
- else self.network(image)
- )
+ log_probs = self.forward(image)
raw_pred, _ = greedy_decoder(
predictions=log_probs,
character_mapper=self.mapper,
- blank_label=80,
+ blank_label=79,
collapse_repeated=True,
)