summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/line_ctc_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/line_ctc_model.py')
-rw-r--r--src/text_recognizer/models/line_ctc_model.py20
1 files changed, 17 insertions, 3 deletions
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 97308a7..af41f18 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -62,12 +62,26 @@ class LineCTCModel(Model):
Tensor: The CTC loss.
"""
+
+ # Input lengths on the form [T, B]
input_lengths = torch.full(
size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
)
- target_lengths = torch.full(
- size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+
+ # Configure target tensors for ctc loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 79]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
+
return self.criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
@@ -93,7 +107,7 @@ class LineCTCModel(Model):
raw_pred, _ = greedy_decoder(
predictions=log_probs,
character_mapper=self.mapper,
- blank_label=79,
+ blank_label=80,
collapse_repeated=True,
)