diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
commit | beeaef529e7c893a3475fe27edc880e283373725 (patch) | |
tree | 59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer/models/vision_transformer_model.py | |
parent | 4d7713746eb936832e84852e90292936b933e87d (diff) |
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer/models/vision_transformer_model.py')
-rw-r--r-- | src/text_recognizer/models/vision_transformer_model.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py index 20bd4ca..3d36437 100644 --- a/src/text_recognizer/models/vision_transformer_model.py +++ b/src/text_recognizer/models/vision_transformer_model.py @@ -53,7 +53,7 @@ class VisionTransformerModel(Model): if network_args is not None: self.max_len = network_args["max_len"] else: - self.max_len = 128 + self.max_len = 120 if self._mapper is None: self._mapper = EmnistMapper( @@ -73,10 +73,10 @@ class VisionTransformerModel(Model): confidence_of_predictions = [] trg_indices = [self.mapper(self.init_token)] - for _ in range(self.max_len): + for _ in range(self.max_len - 1): trg = torch.tensor(trg_indices, device=self.device)[None, :].long() - trg, trg_mask = self.network.preprocess_target(trg) - logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + trg = self.network.preprocess_target(trg) + logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) # Convert logits to probabilities. probs = self.softmax(logits) @@ -112,6 +112,8 @@ class VisionTransformerModel(Model): # Put the image tensor on the device the model weights are on. image = image.to(self.device) - predicted_characters, confidence_of_prediction = self._generate_sentence(image) + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) return predicted_characters, confidence_of_prediction |