diff options
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 9b02e78..285b715 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -26,13 +26,10 @@ class LitTransformerModel(LitBaseModel): mapping: Optional[List[str]] = None, ) -> None: super().__init__( - network_args, - optimizer_args, - lr_scheduler_args, - criterion_args, - monitor) + network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor + ) - self.mapping, ignore_tokens = self.configure_mapping(mapping) + self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) @@ -47,17 +44,19 @@ class LitTransformerModel(LitBaseModel): start_index = inverse_mapping["<s>"] end_index = inverse_mapping["<e>"] pad_index = inverse_mapping["<p>"] - ignore_tokens = [start_index, end_index, pad_index] + ignore_tokens = [start_index, end_index, pad_index] # TODO: add case for sentence pieces return mapping, ignore_tokens def _log_prediction(self, data: Tensor, pred: Tensor) -> None: """Logs prediction on image with wandb.""" - pred_str = "".join(self.mapping[i] for i in pred[0].tolist() if i != 3) # pad index is 3 + pred_str = "".join( + self.mapping[i] for i in pred[0].tolist() if i != 3 + ) # pad index is 3 try: - self.logger.experiment.log({ - "val_pred_examples": [wandb.Image(data[0], caption=pred_str)] - }) + self.logger.experiment.log( + {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]} + ) except AttributeError: pass |