# Helper for locating the first occurrence of a value along a tensor dimension.
from typing import Union

from torch import Tensor


def first_element(x: Tensor, element: Union[int, float], dim: int = 1) -> Tensor:
    """Return the index of the first occurrence of ``element`` along ``dim``.

    Args:
        x: Tensor to search.
        element: Value compared against ``x`` with ``==``.
        dim: Dimension along which to search; it is reduced away in the result.

    Returns:
        An index tensor with ``dim`` removed. Each entry is the position of
        the first match in the corresponding slice, or ``x.shape[dim]`` when
        the slice contains no match.
    """
    matches = x == element
    # A position is the first match iff it matches and the running count of
    # matches up to and including it is exactly 1.
    is_first = (matches.cumsum(dim) == 1) & matches
    # ``values`` is True only where the slice has a match; ``indices`` is the
    # position of that single True entry.
    found, ind = is_first.max(dim)
    # Mark "not found" from the mask rather than from ``ind == 0``, so a
    # legitimate match at position 0 is not mistaken for a missing element.
    ind[~found] = x.shape[dim]
    return ind