summary refs log tree commit diff
path: root/src/text_recognizer/models/line_ctc_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/line_ctc_model.py')
-rw-r--r-- src/text_recognizer/models/line_ctc_model.py 8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 16eaed3..cdc2d8b 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -51,7 +51,7 @@ class LineCTCModel(Model):
self._mapper = EmnistMapper()
self.tensor_transform = ToTensor()
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
"""Computes the CTC loss.
Args:
@@ -82,11 +82,13 @@ class LineCTCModel(Model):
torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
- return self.criterion(output, targets, input_lengths, target_lengths)
+ return self._criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
"""Predict on a single input."""
+ self.eval()
+
if image.dtype == np.uint8:
# Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
image = self.tensor_transform(image)
@@ -110,6 +112,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(log_probs.sum()).item()
+ confidence_of_prediction = torch.exp(-log_probs.sum()).item()
return predicted_characters, confidence_of_prediction