summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 9b02e78..285b715 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -26,13 +26,10 @@ class LitTransformerModel(LitBaseModel):
mapping: Optional[List[str]] = None,
) -> None:
super().__init__(
- network_args,
- optimizer_args,
- lr_scheduler_args,
- criterion_args,
- monitor)
+ network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor
+ )
- self.mapping, ignore_tokens = self.configure_mapping(mapping)
+ self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
@@ -47,17 +44,19 @@ class LitTransformerModel(LitBaseModel):
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
pad_index = inverse_mapping["<p>"]
- ignore_tokens = [start_index, end_index, pad_index]
+ ignore_tokens = [start_index, end_index, pad_index]
# TODO: add case for sentence pieces
return mapping, ignore_tokens
def _log_prediction(self, data: Tensor, pred: Tensor) -> None:
"""Logs prediction on image with wandb."""
- pred_str = "".join(self.mapping[i] for i in pred[0].tolist() if i != 3) # pad index is 3
+ pred_str = "".join(
+ self.mapping[i] for i in pred[0].tolist() if i != 3
+ ) # pad index is 3
try:
- self.logger.experiment.log({
- "val_pred_examples": [wandb.Image(data[0], caption=pred_str)]
- })
+ self.logger.experiment.log(
+ {"val_pred_examples": [wandb.Image(data[0], caption=pred_str)]}
+ )
except AttributeError:
pass