blob: 05c9d940f060a2ef1c7d5faede26d9cbd430ceae (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
import torch
import torchvision.transforms as T
from PIL import Image
from torch import Tensor
class ImageStem:
    """Callable preprocessing pipeline that turns a PIL image into a tensor.

    Calling an instance runs three stages in order:

    1. ``pil_transform``   -- transforms applied to the PIL image,
    2. ``pil_to_tensor``   -- conversion of the PIL image to a tensor,
    3. ``torch_transform`` -- transforms applied to the resulting tensor.

    As constructed here, stages 1 and 3 have no transforms configured, so
    the output is simply whatever ``torchvision.transforms.ToTensor``
    produces for the input image.
    """

    def __init__(self) -> None:
        """Build the three pipeline stages with empty transform lists."""
        # Stage 1: PIL-level transforms (none configured).
        self.pil_transform = T.Compose([])
        # Stage 2: PIL image -> tensor conversion.
        self.pil_to_tensor = T.ToTensor()
        # Stage 3: tensor-level transforms (none configured).
        self.torch_transform = torch.nn.Sequential()

    def __call__(self, img: Image.Image) -> Tensor:
        """Run ``img`` through the pipeline and return the resulting tensor.

        Args:
            img: The PIL image to preprocess. ``Image`` in this file is the
                ``PIL.Image`` module, so the image class is ``Image.Image``.

        Returns:
            The output of ``torch_transform`` applied to the tensor form of
            the (PIL-transformed) image.
        """
        img = self.pil_transform(img)
        img = self.pil_to_tensor(img)

        # Only the tensor-level stage runs with autograd disabled.
        with torch.no_grad():
            img = self.torch_transform(img)

        return img
|