diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-24 22:14:17 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-24 22:14:17 +0100 |
commit | 4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch) | |
tree | 04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/text_recognizer/models/transformer_model.py | |
parent | d691b548cd0b6fc4ea184d64261f633789fee021 (diff) |
Many updates, cool stuff on the way.
Diffstat (limited to 'src/text_recognizer/models/transformer_model.py')
-rw-r--r-- | src/text_recognizer/models/transformer_model.py | 12 |
1 files changed, 9 insertions, 3 deletions
```diff
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 12e497f..3f63053 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -6,9 +6,9 @@ import torch
 from torch import nn
 from torch import Tensor
 from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
 
 from text_recognizer.datasets import EmnistMapper
+import text_recognizer.datasets.transforms as transforms
 from text_recognizer.models.base import Model
 from text_recognizer.networks import greedy_decoder
 
@@ -60,13 +60,19 @@ class TransformerModel(Model):
             eos_token=self.eos_token,
             lower=self.lower,
         )
-        self.tensor_transform = ToTensor()
-
+        self.tensor_transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
+        )
         self.softmax = nn.Softmax(dim=2)
 
     @torch.no_grad()
     def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
         src = self.network.extract_image_features(image)
+
+        # Added for vqvae transformer.
+        if isinstance(src, Tuple):
+            src = src[0]
+
         memory = self.network.encoder(src)
 
         confidence_of_predictions = []
```