diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
commit | beeaef529e7c893a3475fe27edc880e283373725 (patch) | |
tree | 59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer/datasets/transforms.py | |
parent | 4d7713746eb936832e84852e90292936b933e87d (diff) |
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer/datasets/transforms.py')
-rw-r--r-- | src/text_recognizer/datasets/transforms.py | 27 |
1 file changed, 17 insertions, 10 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index c058972..8deac7f 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -3,7 +3,7 @@ import numpy as np from PIL import Image import torch from torch import Tensor -from torchvision.transforms import Compose, ToTensor +from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor from text_recognizer.datasets.util import EmnistMapper @@ -19,28 +19,35 @@ class Transpose: class AddTokens: """Adds start of sequence and end of sequence tokens to target tensor.""" - def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None: + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: self.init_token = init_token self.pad_token = pad_token self.eos_token = eos_token - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) self.pad_value = self.emnist_mapper(self.pad_token) - self.sos_value = self.emnist_mapper(self.init_token) self.eos_value = self.emnist_mapper(self.eos_token) def __call__(self, target: Tensor) -> Tensor: """Adds an sos token to the beginning and an eos token to the end of a target sequence.""" dtype, device = target.dtype, target.device - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) # Find where the padding starts. 
pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() target[pad_index] = self.eos_value - target = torch.cat([sos, target], dim=0) + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + return target |