From beeaef529e7c893a3475fe27edc880e283373725 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 8 Nov 2020 12:41:04 +0100 Subject: Trying to get the CNNTransformer to work, but it is hard. --- src/text_recognizer/datasets/transforms.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) (limited to 'src/text_recognizer/datasets/transforms.py') diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index c058972..8deac7f 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -3,7 +3,7 @@ import numpy as np from PIL import Image import torch from torch import Tensor -from torchvision.transforms import Compose, ToTensor +from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor from text_recognizer.datasets.util import EmnistMapper @@ -19,28 +19,35 @@ class Transpose: class AddTokens: """Adds start of sequence and end of sequence tokens to target tensor.""" - def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None: + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: self.init_token = init_token self.pad_token = pad_token self.eos_token = eos_token - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) self.pad_value = self.emnist_mapper(self.pad_token) - self.sos_value = self.emnist_mapper(self.init_token) self.eos_value = self.emnist_mapper(self.eos_token) def __call__(self, target: Tensor) -> Tensor: """Adds a sos token to the begining and a eos token to the end of a target sequence.""" dtype, device = target.dtype, target.device - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) # Find the where padding starts. pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() target[pad_index] = self.eos_value - target = torch.cat([sos, target], dim=0) + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + return target -- cgit v1.2.3-70-g09d2