summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/vision_transformer_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/vision_transformer_model.py')
-rw-r--r--src/text_recognizer/models/vision_transformer_model.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
index 20bd4ca..3d36437 100644
--- a/src/text_recognizer/models/vision_transformer_model.py
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -53,7 +53,7 @@ class VisionTransformerModel(Model):
if network_args is not None:
self.max_len = network_args["max_len"]
else:
- self.max_len = 128
+ self.max_len = 120
if self._mapper is None:
self._mapper = EmnistMapper(
@@ -73,10 +73,10 @@ class VisionTransformerModel(Model):
confidence_of_predictions = []
trg_indices = [self.mapper(self.init_token)]
- for _ in range(self.max_len):
+ for _ in range(self.max_len - 1):
trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
- trg, trg_mask = self.network.preprocess_target(trg)
- logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ trg = self.network.preprocess_target(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
# Convert logits to probabilities.
probs = self.softmax(logits)
@@ -112,6 +112,8 @@ class VisionTransformerModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
return predicted_characters, confidence_of_prediction