summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/models/transformer_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/transformer_model.py')
-rw-r--r-- src/text_recognizer/models/transformer_model.py | 12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 12e497f..3f63053 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -6,9 +6,9 @@ import torch
from torch import nn
from torch import Tensor
from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
from text_recognizer.datasets import EmnistMapper
+import text_recognizer.datasets.transforms as transforms
from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
@@ -60,13 +60,19 @@ class TransformerModel(Model):
eos_token=self.eos_token,
lower=self.lower,
)
- self.tensor_transform = ToTensor()
-
+ self.tensor_transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
+ )
self.softmax = nn.Softmax(dim=2)
@torch.no_grad()
def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
src = self.network.extract_image_features(image)
+
+ # Added for vqvae transformer.
+ if isinstance(src, Tuple):
+ src = src[0]
+
memory = self.network.encoder(src)
confidence_of_predictions = []