summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 18:21:53 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 18:21:53 +0200
commit5dc8a7097ab6b4f39f0a3add408e3fd0f131f85b (patch)
tree6c114608e7281b66d08435c60367979768c266b9
parent03dae09c63f2079f37bbf25fd9ded6f20f1490da (diff)
black reformatting
-rw-r--r--text_recognizer/models/transformer.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 9b02e78..285b715 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -26,13 +26,10 @@ class LitTransformerModel(LitBaseModel):
mapping: Optional[List[str]] = None,
) -> None:
super().__init__(
- network_args,
- optimizer_args,
- lr_scheduler_args,
- criterion_args,
- monitor)
+ network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor
+ )
- self.mapping, ignore_tokens = self.configure_mapping(mapping)
+ self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
@@ -47,17 +44,19 @@ class LitTransformerModel(LitBaseModel):
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
pad_index = inverse_mapping["<p>"]
- ignore_tokens = [start_index, end_index, pad_index]
+ ignore_tokens = [start_index, end_index, pad_index]
# TODO: add case for sentence pieces
return mapping, ignore_tokens
def _log_prediction(self, data: Tensor, pred: Tensor) -> None:
"""Logs prediction on image with wandb."""
- pred_str = "".join(self.mapping[i] for i in pred[0].tolist() if i != 3) # pad index is 3
+ pred_str = "".join(
+ self.mapping[i] for i in pred[0].tolist() if i != 3
+ ) # pad index is 3
try:
- self.logger.experiment.log({
- "val_pred_examples": [wandb.Image(data[0], caption=pred_str)]
- })
+ self.logger.experiment.log(
+ {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]}
+ )
except AttributeError:
pass