diff options
Diffstat (limited to 'notebooks/Untitled1.ipynb')
-rw-r--r-- | notebooks/Untitled1.ipynb | 126 |
1 files changed, 125 insertions, 1 deletions
diff --git a/notebooks/Untitled1.ipynb b/notebooks/Untitled1.ipynb index 06129a3..a2d6168 100644 --- a/notebooks/Untitled1.ipynb +++ b/notebooks/Untitled1.ipynb @@ -536,8 +536,132 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "c0e733d8-c17d-46f5-b484-9c74e46d7308", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from typing import Union\n", + "\n", + "import torch\n", + "\n", + "\n", + "def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:\n", + " \"\"\"Return indices of first appearance of element in x, collapsing along dim.\n", + "\n", + " Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9\n", + "\n", + " Parameters\n", + " ----------\n", + " x\n", + " One or two-dimensional Tensor to search for element.\n", + " element\n", + " Item to search for inside x.\n", + " dim\n", + " Dimension of Tensor to collapse over.\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " Indices where element occurs in x. If element is not found,\n", + " return length of x along dim. One dimension smaller than x.\n", + "\n", + " Raises\n", + " ------\n", + " ValueError\n", + " if x is not a 1 or 2 dimensional Tensor\n", + "\n", + " Examples\n", + " --------\n", + " >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)\n", + " tensor([2, 1, 3, 0])\n", + " >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)\n", + " tensor(0)\n", + " \"\"\"\n", + " if x.dim() > 2 or x.dim() == 0:\n", + " raise ValueError(f\"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}\")\n", + " matches = x == element\n", + " first_appearance_mask = (matches.cumsum(dim) == 1) & matches\n", + " does_match, match_index = first_appearance_mask.max(dim)\n", + " first_inds = torch.where(does_match, match_index, x.shape[dim])\n", + " return first_inds" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "26ff2314-b2df-408f-b83a-f5fc903145da", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([2, 3])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "first_appearance(torch.tensor([[1, 1, 3], [1, 1, 1]]), 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e8c9dd16-4917-40bc-8504-084035882ced", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ True, False])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.any(torch.isin(torch.tensor([[1, 1, 3], [1, 1, 1]]), 3), 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "30eb1926-53d2-431a-b7c1-f95919887b84", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([2, 0])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.tensor([[1, 1, 3], [1, 1, 1]]).argmax(dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a02c61b3-c84f-4778-90d0-fe6aafa12ccc", "metadata": {}, "outputs": [], "source": [] |