summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/transforms.py')
-rw-r--r--src/text_recognizer/datasets/transforms.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index c058972..8deac7f 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,7 +3,7 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
-from torchvision.transforms import Compose, ToTensor
+from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
from text_recognizer.datasets.util import EmnistMapper
@@ -19,28 +19,35 @@ class Transpose:
class AddTokens:
"""Adds start of sequence and end of sequence tokens to target tensor."""
- def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None:
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
self.init_token = init_token
self.pad_token = pad_token
self.eos_token = eos_token
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
self.pad_value = self.emnist_mapper(self.pad_token)
- self.sos_value = self.emnist_mapper(self.init_token)
self.eos_value = self.emnist_mapper(self.eos_token)
def __call__(self, target: Tensor) -> Tensor:
"""Adds a sos token to the begining and a eos token to the end of a target sequence."""
dtype, device = target.dtype, target.device
- sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
# Find the where padding starts.
pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
target[pad_index] = self.eos_value
- target = torch.cat([sos, target], dim=0)
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
return target