diff options
Diffstat (limited to 'jax_optimization_tutorial')
-rw-r--r-- | jax_optimization_tutorial/basic-optimization.ipynb | 144 |
1 files changed, 114 insertions, 30 deletions
diff --git a/jax_optimization_tutorial/basic-optimization.ipynb b/jax_optimization_tutorial/basic-optimization.ipynb index ddac210..809934c 100644 --- a/jax_optimization_tutorial/basic-optimization.ipynb +++ b/jax_optimization_tutorial/basic-optimization.ipynb @@ -11,29 +11,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "id": "02e5a88f", "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [1]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m grad, jit, vmap\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m random\n", - "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/__init__.py:37\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _cloud_tpu_init\n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m# Confusingly there are two things named \"config\": the module and the class.\u001b[39;00m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# We want the exported object to be the class, so we first import the module\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# to make sure a later import doesn't overwrite the class.\u001b[39;00m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config \u001b[38;5;28;01mas\u001b[39;00m _config_module\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m _config_module\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 41\u001b[0m config \u001b[38;5;28;01mas\u001b[39;00m config,\n\u001b[1;32m 42\u001b[0m enable_checks \u001b[38;5;28;01mas\u001b[39;00m enable_checks,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 56\u001b[0m transfer_guard_device_to_host \u001b[38;5;28;01mas\u001b[39;00m transfer_guard_device_to_host,\n\u001b[1;32m 57\u001b[0m )\n", - "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/config.py:18\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Copyright 2018 Google LLC\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 16\u001b[0m \n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# flake8: noqa: F401\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m config\n", - "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/config.py:27\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Any, List, Callable, NamedTuple, Iterator, Optional\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mwarnings\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m lib\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_src\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m jax_jit\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lib\u001b[38;5;241m.\u001b[39mxla_extension_version \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m58\u001b[39m:\n", - "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:87\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _jaxlib_version\n\u001b[1;32m 86\u001b[0m version_str \u001b[38;5;241m=\u001b[39m jaxlib\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39m__version__\n\u001b[0;32m---> 87\u001b[0m version \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_jaxlib_version\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 88\u001b[0m \u001b[43m \u001b[49m\u001b[43mjax_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[43m \u001b[49m\u001b[43mjaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjaxlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__version__\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[43m \u001b[49m\u001b[43mminimum_jaxlib_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minimum_jaxlib_version\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;66;03m# Before importing any C compiled modules from jaxlib, first import the CPU\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;66;03m# feature guard module to verify that jaxlib was compiled in a way that only\u001b[39;00m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;66;03m# uses instructions that are present on this machine.\u001b[39;00m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjaxlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mcpu_feature_guard\u001b[39;00m\n", - "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/jax-optimization-tutorial--qTljXSK-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:76\u001b[0m, in \u001b[0;36mcheck_jaxlib_version\u001b[0;34m(jax_version, jaxlib_version, minimum_jaxlib_version)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m<\u001b[39m _minimum_jaxlib_version:\n\u001b[1;32m 74\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib is version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but this version \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mof jax requires version >= \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mminimum_jaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 76\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _jaxlib_version \u001b[38;5;241m>\u001b[39m _jax_version:\n\u001b[1;32m 79\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjaxlib version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjaxlib_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is newer than and \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mincompatible with jax version \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjax_version\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Please \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mupdate your jax and/or jaxlib packages.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", - "\u001b[0;31mRuntimeError\u001b[0m: jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0." - ] - } - ], + "outputs": [], "source": [ - "import jax.numpy as np\n", + "import jax.numpy as jnp\n", "from jax import grad, jit, vmap\n", "from jax import random\n", "from jax import jacfwd, jacrev\n", @@ -43,7 +26,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "id": "9bad5173", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gpu\n" + ] + } + ], + "source": [ + "from jax.lib import xla_bridge\n", + "print(xla_bridge.get_backend().platform)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "id": "afe9f53c", "metadata": {}, "outputs": [], @@ -80,7 +82,7 @@ ], "source": [ "gradf = grad(f)\n", - "gradf(np.array([2.0]))" + "gradf(jnp.array([2.0]))" ] }, { @@ -121,7 +123,7 @@ } ], "source": [ - "J(np.array([1.0, 2.0]))" + "J(jnp.array([1.0, 2.0]))" ] }, { @@ -163,7 +165,7 @@ } ], "source": [ - "H(np.array([1.0, 2.0]))" + "H(jnp.array([1.0, 2.0]))" ] }, { @@ -183,8 +185,8 @@ "metadata": {}, "outputs": [], "source": [ - "def y(x): return ((x * np.sqrt(12 * x - 36)) / (2 * (x - 3)))\n", - "def L(x): return np.sqrt(x ** 2 + y(x) ** 2)" + "def y(x): return ((x * jnp.sqrt(12 * x - 36)) / (2 * (x - 3)))\n", + "def L(x): return jnp.sqrt(x ** 2 + y(x) ** 2)" ] }, { @@ -214,7 +216,7 @@ "metadata": {}, "outputs": [], "source": [ - "domain = np.linspace(3.0, 5.0, num=50)" + "domain = jnp.linspace(3.0, 5.0, num=50)" ] }, { @@ -405,7 +407,7 @@ "metadata": {}, "outputs": [], "source": [ - "domain = np.linspace(3.0, 5.0, num=50)" + "domain = jnp.linspace(3.0, 5.0, num=50)" ] }, { @@ -954,15 +956,97 @@ "id": "0bf20ea5", "metadata": {}, "source": [ - "# Multivariable MultiConstrained Optimization" + "# Multivariable Multiconstrained Optimization\n", + "\n", + "The general form for of the Lagrangian function can be written as:\n", + "\n", + "$$\n", + " \\mathcal{L}(X)=\\nabla f(X) - \\Sigma_{n=1}^{M}\\nabla g_n(X) = 0\n", + "$$" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "e834f8ba", "metadata": {}, "outputs": [], + "source": [ + "def f(x): return 13*x[0]**2 + 10*x[0]*x[1] + 7*x[1]**2 + x[0] + x[1]\n", + "def g(x): return 2*x[0] - 5*x[1] - 2\n", + "def h(x): return x[0] + x[1] - 1" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1fe85659", + "metadata": {}, + "outputs": [], + "source": [ + "f_min = vmap(f)\n", + "\n", + "def lagrange(l): return f(l[0:2]) - l[2]*g(l[:2]) - l[3]*h(l[:2])\n", + "\n", + "L = jacfwd(lagrange)\n", + "gradL = jacfwd(L)\n", + "\n", + "def solve_lagrange(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4497a49e", + "metadata": {}, + "outputs": [], + "source": [ + "domain = random.uniform(key, shape=(300,4), dtype=\"float32\", minval=-4, maxval=1)\n", + "\n", + "v_solve = vmap(solve_lagrange)\n", + "for epoch in range(300):\n", + " domain = v_solve(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f0e0fa5d", + "metadata": {}, + "outputs": [], + "source": [ + "maximums = f_min(domain)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ce6ee395", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The minimum is 13.999994277954102, the argmin is (0.9999997615814209,-1.063081178642733e-08), the lagrangian is 2.2857134342193604 and 22.428564071655273\n" + ] + } + ], + "source": [ + "arglist = nanargmax(maximums)\n", + "argmin = domain[arglist]\n", + "minimum = maximums[arglist]\n", + "\n", + "print(\"The minimum is {}, the argmin is ({},{}), the lagrangian is {} and {}\".format(minimum,\n", + " argmin[0], argmin[1], argmin[2], argmin[3]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42d4e0bc", + "metadata": {}, + "outputs": [], "source": [] } ], |