summaryrefslogtreecommitdiff
path: root/jax_optimization_tutorial
diff options
context:
space:
mode:
Diffstat (limited to 'jax_optimization_tutorial')
-rw-r--r--jax_optimization_tutorial/basic-optimization.ipynb144
1 files changed, 114 insertions, 30 deletions
diff --git a/jax_optimization_tutorial/basic-optimization.ipynb b/jax_optimization_tutorial/basic-optimization.ipynb
index ddac210..809934c 100644
--- a/jax_optimization_tutorial/basic-optimization.ipynb
+++ b/jax_optimization_tutorial/basic-optimization.ipynb
@@ -11,29 +11,12 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 7,
"id": "02e5a88f",
"metadata": {},
- "outputs": [
- {
- "ename": "RuntimeError",
- "evalue": "jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0.",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "Input \u001b[0;32mIn [1]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m grad, jit, vmap\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m random\n",
- "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/__init__.py:37\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _cloud_tpu_init\n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m# Confusingly there are two things named \"config\": the module and the class.\u001b[39;00m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# We want the exported object to be the class, so we first import the module\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# to make sure a later import doesn't overwrite the class.\u001b[39;00m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config \u001b[38;5;28;01mas\u001b[39;00m _config_module\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _config_module\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 41\u001b[0m config \u001b[38;5;28;01mas\u001b[39;00m config,\n\u001b[1;32m 42\u001b[0m enable_checks \u001b[38;5;28;01mas\u001b[39;00m enable_checks,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 56\u001b[0m transfer_guard_device_to_host \u001b[38;5;28;01mas\u001b[39;00m transfer_guard_device_to_host,\n\u001b[1;32m 57\u001b[0m )\n",
- "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/config.py:18\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Copyright 2018 Google LLC\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 16\u001b[0m \n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config\n",
- "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/config.py:27\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Any, List, Callable, NamedTuple, Iterator, Optional\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mwarnings\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m lib\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m jax_jit\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lib\u001b[38;5;241m.\u001b[39mxla_extension_version \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m58\u001b[39m:\n",
- "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:87\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _jaxlib_version\n\u001b[1;32m 86\u001b[0m version_str \u001b[38;5;241m=\u001b[39m jaxlib\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39m__version__\n\u001b[0;32m---> 87\u001b[0m version \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_jaxlib_version\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 88\u001b[0m \u001b[43m \u001b[49m\u001b[43mjax_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[43m \u001b[49m\u001b[43mjaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjaxlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[43m \u001b[49m\u001b[43mminimum_jaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minimum_jaxlib_version\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;66;03m# Before importing any C compiled modules from jaxlib, first import the CPU\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;66;03m# feature guard module to verify that jaxlib was compiled in a way that only\u001b[39;00m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;66;03m# uses instructions that are present on this machine.\u001b[39;00m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjaxlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m\n",
- "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:76\u001b[0m, in \u001b[0;36mcheck_jaxlib_version\u001b[0;34m(jax_version, jaxlib_version, minimum_jaxlib_version)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m<\u001b[39m _minimum_jaxlib_version:\n\u001b[1;32m 74\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib is version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but this version \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mof jax requires version >= \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mminimum_jaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 76\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m>\u001b[39m _jax_version:\n\u001b[1;32m 79\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is newer than and \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mincompatible with jax version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjax_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Please \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mupdate your jax and/or jaxlib packages.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
- "\u001b[0;31mRuntimeError\u001b[0m: jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0."
- ]
- }
- ],
+ "outputs": [],
"source": [
- "import jax.numpy as np\n",
+ "import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random\n",
"from jax import jacfwd, jacrev\n",
@@ -43,7 +26,26 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
+ "id": "9bad5173",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "gpu\n"
+ ]
+ }
+ ],
+ "source": [
+ "from jax.lib import xla_bridge\n",
+ "print(xla_bridge.get_backend().platform)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
"id": "afe9f53c",
"metadata": {},
"outputs": [],
@@ -80,7 +82,7 @@
],
"source": [
"gradf = grad(f)\n",
- "gradf(np.array([2.0]))"
+ "gradf(jnp.array([2.0]))"
]
},
{
@@ -121,7 +123,7 @@
}
],
"source": [
- "J(np.array([1.0, 2.0]))"
+ "J(jnp.array([1.0, 2.0]))"
]
},
{
@@ -163,7 +165,7 @@
}
],
"source": [
- "H(np.array([1.0, 2.0]))"
+ "H(jnp.array([1.0, 2.0]))"
]
},
{
@@ -183,8 +185,8 @@
"metadata": {},
"outputs": [],
"source": [
- "def y(x): return ((x * np.sqrt(12 * x - 36)) / (2 * (x - 3)))\n",
- "def L(x): return np.sqrt(x ** 2 + y(x) ** 2)"
+ "def y(x): return ((x * jnp.sqrt(12 * x - 36)) / (2 * (x - 3)))\n",
+ "def L(x): return jnp.sqrt(x ** 2 + y(x) ** 2)"
]
},
{
@@ -214,7 +216,7 @@
"metadata": {},
"outputs": [],
"source": [
- "domain = np.linspace(3.0, 5.0, num=50)"
+ "domain = jnp.linspace(3.0, 5.0, num=50)"
]
},
{
@@ -405,7 +407,7 @@
"metadata": {},
"outputs": [],
"source": [
- "domain = np.linspace(3.0, 5.0, num=50)"
+ "domain = jnp.linspace(3.0, 5.0, num=50)"
]
},
{
@@ -954,15 +956,97 @@
"id": "0bf20ea5",
"metadata": {},
"source": [
- "# Multivariable MultiConstrained Optimization"
+ "# Multivariable Multiconstrained Optimization\n",
+ "\n",
+ "The general form for of the Lagrangian function can be written as:\n",
+ "\n",
+ "$$\n",
+ " \\mathcal{L}(X)=\\nabla f(X) - \\Sigma_{n=1}^{M}\\nabla g_n(X) = 0\n",
+ "$$"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"id": "e834f8ba",
"metadata": {},
"outputs": [],
+ "source": [
+ "def f(x): return 13*x[0]**2 + 10*x[0]*x[1] + 7*x[1]**2 + x[0] + x[1]\n",
+ "def g(x): return 2*x[0] - 5*x[1] - 2\n",
+ "def h(x): return x[0] + x[1] - 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "1fe85659",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_min = vmap(f)\n",
+ "\n",
+ "def lagrange(l): return f(l[0:2]) - l[2]*g(l[:2]) - l[3]*h(l[:2])\n",
+ "\n",
+ "L = jacfwd(lagrange)\n",
+ "gradL = jacfwd(L)\n",
+ "\n",
+ "def solve_lagrange(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "4497a49e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "domain = random.uniform(key, shape=(300,4), dtype=\"float32\", minval=-4, maxval=1)\n",
+ "\n",
+ "v_solve = vmap(solve_lagrange)\n",
+ "for epoch in range(300):\n",
+ " domain = v_solve(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "f0e0fa5d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "maximums = f_min(domain)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "ce6ee395",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The minimum is 13.999994277954102, the argmin is (0.9999997615814209,-1.063081178642733e-08), the lagrangian is 2.2857134342193604 and 22.428564071655273\n"
+ ]
+ }
+ ],
+ "source": [
+ "arglist = nanargmax(maximums)\n",
+ "argmin = domain[arglist]\n",
+ "minimum = maximums[arglist]\n",
+ "\n",
+ "print(\"The minimum is {}, the argmin is ({},{}), the lagrangian is {} and {}\".format(minimum,\n",
+ " argmin[0], argmin[1], argmin[2], argmin[3]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "42d4e0bc",
+ "metadata": {},
+ "outputs": [],
"source": []
}
],