diff options
Diffstat (limited to 'src/text_recognizer/models/line_ctc_model.py')
-rw-r--r-- | src/text_recognizer/models/line_ctc_model.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py index 16eaed3..cdc2d8b 100644 --- a/src/text_recognizer/models/line_ctc_model.py +++ b/src/text_recognizer/models/line_ctc_model.py @@ -51,7 +51,7 @@ class LineCTCModel(Model): self._mapper = EmnistMapper() self.tensor_transform = ToTensor() - def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: + def criterion(self, output: Tensor, targets: Tensor) -> Tensor: """Computes the CTC loss. Args: @@ -82,11 +82,13 @@ class LineCTCModel(Model): torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) ) - return self.criterion(output, targets, input_lengths, target_lengths) + return self._criterion(output, targets, input_lengths, target_lengths) @torch.no_grad() def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: """Predict on a single input.""" + self.eval() + if image.dtype == np.uint8: # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. image = self.tensor_transform(image) @@ -110,6 +112,6 @@ class LineCTCModel(Model): log_probs, _ = log_probs.max(dim=2) predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = torch.exp(log_probs.sum()).item() + confidence_of_prediction = torch.exp(-log_probs.sum()).item() return predicted_characters, confidence_of_prediction |