diff options
Diffstat (limited to 'src/text_recognizer/models/transformer_model.py')
-rw-r--r-- | src/text_recognizer/models/transformer_model.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py index 12e497f..3f63053 100644 --- a/src/text_recognizer/models/transformer_model.py +++ b/src/text_recognizer/models/transformer_model.py @@ -6,9 +6,9 @@ import torch from torch import nn from torch import Tensor from torch.utils.data import Dataset -from torchvision.transforms import ToTensor from text_recognizer.datasets import EmnistMapper +import text_recognizer.datasets.transforms as transforms from text_recognizer.models.base import Model from text_recognizer.networks import greedy_decoder @@ -60,13 +60,19 @@ class TransformerModel(Model): eos_token=self.eos_token, lower=self.lower, ) - self.tensor_transform = ToTensor() - + self.tensor_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] + ) self.softmax = nn.Softmax(dim=2) @torch.no_grad() def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: src = self.network.extract_image_features(image) + + # Added for vqvae transformer. + if isinstance(src, Tuple): + src = src[0] + memory = self.network.encoder(src) confidence_of_predictions = [] |