diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
commit | 3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch) | |
tree | e1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer/models/line_ctc_model.py | |
parent | 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff) |
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer/models/line_ctc_model.py')
-rw-r--r-- | src/text_recognizer/models/line_ctc_model.py | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py index 97308a7..af41f18 100644 --- a/src/text_recognizer/models/line_ctc_model.py +++ b/src/text_recognizer/models/line_ctc_model.py @@ -62,12 +62,26 @@ class LineCTCModel(Model): Tensor: The CTC loss. """ + + # Input lengths on the form [T, B] input_lengths = torch.full( size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, ) - target_lengths = torch.full( - size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + + # Configure target tensors for ctc loss. + targets_ = Tensor([]).to(self.device) + target_lengths = [] + for t in targets: + # Remove padding symbol as it acts as the blank symbol. + t = t[t < 79] + targets_ = torch.cat([targets_, t]) + target_lengths.append(len(t)) + + targets = targets_.type(dtype=torch.long) + target_lengths = ( + torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) ) + return self.criterion(output, targets, input_lengths, target_lengths) @torch.no_grad() @@ -93,7 +107,7 @@ class LineCTCModel(Model): raw_pred, _ = greedy_decoder( predictions=log_probs, character_mapper=self.mapper, - blank_label=79, + blank_label=80, collapse_repeated=True, ) |