summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/image.py
blob: f04b3a03dcf94e4b2eca0218eacb19b93075107c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from PIL import Image
import torch
from torch import Tensor
import torchvision.transforms as T


class ImageStem:
    """Base preprocessing pipeline that turns a PIL image into a tensor.

    ``__call__`` applies three stages in order:

    1. ``pil_transform`` -- transforms applied to the PIL image.
    2. ``pil_to_tensor`` -- conversion to a tensor via ``T.ToTensor()``.
    3. ``torch_transform`` -- transforms applied to the resulting tensor,
       run under ``torch.no_grad()``.

    In this base class, stages 1 and 3 are empty (an empty ``T.Compose`` and
    an empty ``torch.nn.Sequential``). Both act as identity, so the only
    effect is the tensor conversion. Presumably subclasses overwrite these
    attributes with real pipelines -- NOTE(review): confirm against callers.
    """

    def __init__(self) -> None:
        # Empty pipelines are identity placeholders that subclasses can replace.
        self.pil_transform = T.Compose([])
        self.pil_to_tensor = T.ToTensor()
        self.torch_transform = torch.nn.Sequential()

    def __call__(self, img: Image.Image) -> Tensor:
        """Run the PIL transforms, convert to a tensor, then run tensor transforms.

        Args:
            img: Input PIL image.

        Returns:
            The image tensor after all three stages have been applied.
        """
        img = self.pil_transform(img)
        img = self.pil_to_tensor(img)
        # Run tensor-side preprocessing without recording autograd history.
        with torch.no_grad():
            img = self.torch_transform(img)
        return img