summaryrefslogtreecommitdiff
path: root/notebooks/Untitled1.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/Untitled1.ipynb')
-rw-r--r--notebooks/Untitled1.ipynb126
1 files changed, 125 insertions, 1 deletions
diff --git a/notebooks/Untitled1.ipynb b/notebooks/Untitled1.ipynb
index 06129a3..a2d6168 100644
--- a/notebooks/Untitled1.ipynb
+++ b/notebooks/Untitled1.ipynb
@@ -536,8 +536,132 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "c0e733d8-c17d-46f5-b484-9c74e46d7308",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from typing import Union\n",
+ "\n",
+ "import torch\n",
+ "\n",
+ "\n",
+ "def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:\n",
+ " \"\"\"Return indices of first appearance of element in x, collapsing along dim.\n",
+ "\n",
+ " Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x\n",
+ " One or two-dimensional Tensor to search for element.\n",
+ " element\n",
+ " Item to search for inside x.\n",
+ " dim\n",
+ " Dimension of Tensor to collapse over.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " torch.Tensor\n",
+ " Indices where element occurs in x. If element is not found,\n",
+ " return length of x along dim. One dimension smaller than x.\n",
+ "\n",
+ " Raises\n",
+ " ------\n",
+ " ValueError\n",
+ " if x is not a 1 or 2 dimensional Tensor\n",
+ "\n",
+ " Examples\n",
+ " --------\n",
+ " >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)\n",
+ " tensor([2, 1, 3, 0])\n",
+ " >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)\n",
+ " tensor(0)\n",
+ " \"\"\"\n",
+ " if x.dim() > 2 or x.dim() == 0:\n",
+ " raise ValueError(f\"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}\")\n",
+ " matches = x == element\n",
+ " first_appearance_mask = (matches.cumsum(dim) == 1) & matches\n",
+ " does_match, match_index = first_appearance_mask.max(dim)\n",
+ " first_inds = torch.where(does_match, match_index, x.shape[dim])\n",
+ " return first_inds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "26ff2314-b2df-408f-b83a-f5fc903145da",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([2, 3])"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "first_appearance(torch.tensor([[1, 1, 3], [1, 1, 1]]), 3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "e8c9dd16-4917-40bc-8504-084035882ced",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ True, False])"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.any(torch.isin(torch.tensor([[1, 1, 3], [1, 1, 1]]), 3), 1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "30eb1926-53d2-431a-b7c1-f95919887b84",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([2, 0])"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.tensor([[1, 1, 3], [1, 1, 1]]).argmax(dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a02c61b3-c84f-4778-90d0-fe6aafa12ccc",
"metadata": {},
"outputs": [],
"source": []