From 5dc8a7097ab6b4f39f0a3add408e3fd0f131f85b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 4 Apr 2021 18:21:53 +0200 Subject: black reformatting --- text_recognizer/models/transformer.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 9b02e78..285b715 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -26,13 +26,10 @@ class LitTransformerModel(LitBaseModel): mapping: Optional[List[str]] = None, ) -> None: super().__init__( - network_args, - optimizer_args, - lr_scheduler_args, - criterion_args, - monitor) + network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor + ) - self.mapping, ignore_tokens = self.configure_mapping(mapping) + self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) @@ -47,17 +44,19 @@ class LitTransformerModel(LitBaseModel): start_index = inverse_mapping[""] end_index = inverse_mapping[""] pad_index = inverse_mapping["

"] - ignore_tokens = [start_index, end_index, pad_index] + ignore_tokens = [start_index, end_index, pad_index] # TODO: add case for sentence pieces return mapping, ignore_tokens def _log_prediction(self, data: Tensor, pred: Tensor) -> None: """Logs prediction on image with wandb.""" - pred_str = "".join(self.mapping[i] for i in pred[0].tolist() if i != 3) # pad index is 3 + pred_str = "".join( + self.mapping[i] for i in pred[0].tolist() if i != 3 + ) # pad index is 3 try: - self.logger.experiment.log({ - "val_pred_examples": [wandb.Image(data[0], caption=pred_str)] - }) + self.logger.experiment.log( + {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]} + ) except AttributeError: pass -- cgit v1.2.3-70-g09d2