summaryrefslogtreecommitdiff
path: root/text_recognizer/models/util.py
blob: cc0df3cec9d1cf3ec0e2eb26861fa15859b2914f (plain)
1
2
3
4
5
6
7
8
9
10
from typing import Union

from torch import Tensor


def first_element(x: Tensor, element: Union[int, float], dim: int = 1) -> Tensor:
    """Return the index of the first occurrence of ``element`` along ``dim``.

    Args:
        x: Tensor to search.
        element: Value to look for; compared element-wise via ``x == element``.
        dim: Dimension to search along (negative values are allowed).
            Defaults to 1.

    Returns:
        An integer (``long``) tensor with ``dim`` reduced away. Each entry is
        the index of the first match in that slice, or ``x.shape[dim]`` when
        ``element`` does not occur in the slice.
    """
    matches = x == element
    # A running count equal to 1 on a matching position marks exactly the
    # first match in each slice, so the mask is one-hot (or all-False).
    first_mask = (matches.cumsum(dim) == 1) & matches
    # With a one-hot mask the maximum is unique, so argmax picks the first
    # match without relying on tie-breaking. Reduce an integer copy rather
    # than the bool mask itself.
    indices = first_mask.long().argmax(dim)
    # Detect "no match" explicitly: index 0 is a valid first-match position,
    # so it must not double as the "not found" marker.
    found = matches.any(dim)
    return indices.masked_fill(~found, x.shape[dim])