blob: 17231a8509990317a7ecc3fd6339ab05e2246492 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
|
"""Transforms for PyTorch datasets."""
import numpy as np
from PIL import Image
import torch
from torch import Tensor
class Transpose:
    """Transposes the EMNIST image to the correct orientation.

    Instances are callable, so one can be used wherever a function that
    takes an image and returns an array is expected.
    """

    # The annotation names the image class ``Image.Image`` rather than the
    # ``Image`` module itself; it is quoted so it is not evaluated when the
    # class body is executed.
    def __call__(self, image: "Image.Image") -> np.ndarray:
        """Return ``image`` as an array with its first two axes swapped.

        Args:
            image: The image to transpose; it is converted with ``np.array``.

        Returns:
            The converted array with axes 0 and 1 exchanged. Any further
            axes keep their position.
        """
        pixels = np.array(image)
        return pixels.swapaxes(0, 1)
|