summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/transforms.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
commit4d7713746eb936832e84852e90292936b933e87d (patch)
tree2b2519d1d2ce53d4e1390590f52018d55dadbc7c /src/text_recognizer/datasets/transforms.py
parent1b3b8073a19f939d18a0bb85247eb0d99284f7cc (diff)
Transfomer added, many other changes.
Diffstat (limited to 'src/text_recognizer/datasets/transforms.py')
-rw-r--r--src/text_recognizer/datasets/transforms.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 17231a8..c058972 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,6 +3,9 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
+from torchvision.transforms import Compose, ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
class Transpose:
@@ -11,3 +14,33 @@ class Transpose:
def __call__(self, image: Image) -> np.ndarray:
"""Swaps axis."""
return np.array(image).swapaxes(0, 1)
+
+
+class AddTokens:
+ """Adds start of sequence and end of sequence tokens to target tensor."""
+
+ def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.pad_value = self.emnist_mapper(self.pad_token)
+ self.sos_value = self.emnist_mapper(self.init_token)
+ self.eos_value = self.emnist_mapper(self.eos_token)
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+ dtype, device = target.dtype, target.device
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+
+ # Find the where padding starts.
+ pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+ target[pad_index] = self.eos_value
+
+ target = torch.cat([sos, target], dim=0)
+ return target