blob: f04b3a03dcf94e4b2eca0218eacb19b93075107c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
from PIL import Image
import torch
from torch import Tensor
import torchvision.transforms as T
class ImageStem:
    """Callable preprocessing pipeline that turns a PIL image into a tensor.

    An image passes through three stages, in this order:

    1. ``pil_transform``   -- transforms applied to the PIL image.
    2. ``pil_to_tensor``   -- conversion with ``torchvision.transforms.ToTensor``.
    3. ``torch_transform`` -- transforms applied to the resulting tensor,
       executed with gradient tracking disabled.

    Both transform stages are created empty here, so a plain instance only
    performs the tensor conversion.
    NOTE(review): presumably subclasses or callers replace ``pil_transform``
    and ``torch_transform`` with real pipelines -- confirm against usage.
    """

    def __init__(self) -> None:
        """Create the three stages with empty transform pipelines."""
        # Stage 1: PIL-level transforms (empty composition).
        self.pil_transform = T.Compose([])
        # Stage 2: PIL image -> tensor conversion.
        self.pil_to_tensor = T.ToTensor()
        # Stage 3: tensor-level transforms (empty module sequence).
        self.torch_transform = torch.nn.Sequential()

    # ``Image`` is the PIL.Image *module*; the image type is ``Image.Image``.
    # The annotation is quoted so it is not evaluated at definition time.
    def __call__(self, img: "Image.Image") -> Tensor:
        """Run ``img`` through all three stages and return the result.

        Args:
            img: The PIL image to preprocess.

        Returns:
            The output of ``torch_transform`` applied to the converted image.
        """
        transformed = self.pil_transform(img)
        tensor = self.pil_to_tensor(transformed)
        # Preprocessing should not record an autograd graph.
        with torch.no_grad():
            return self.torch_transform(tensor)
|