{ "cells": [ { "cell_type": "markdown", "id": "f48669df", "metadata": {}, "source": [ "# Solving Optimization Problems with JAX\n", "Source: https://medium.com/swlh/solving-optimization-problems-with-jax-98376508bd4f" ] }, { "cell_type": "code", "execution_count": null, "id": "02e5a88f", "metadata": {}, "outputs": [], "source": [ "# NOTE: jax and jaxlib must be mutually compatible. The last saved run of this cell failed with\n", "# RuntimeError: jaxlib is version 0.1.71, but this version of jax requires version >= 0.3.0.\n", "# That stale traceback has been cleared; re-run the notebook in a consistent environment.\n", "import jax.numpy as np\n", "from jax import grad, jit, vmap\n", "from jax import random\n", "from jax import jacfwd, jacrev\n", "from jax.numpy import linalg\n", "from numpy import nanargmin, nanargmax" ] }, { "cell_type": "code", "execution_count": null, "id": "afe9f53c", "metadata": {}, "outputs": [], "source": [ "key = random.PRNGKey(42)" ] }, { "cell_type": "code", "execution_count": 3, "id": "b8bc9bbe", "metadata": {}, "outputs": [], "source": [ "def f(x): return 3 * x[0] ** 2" ] }, { "cell_type": "code", "execution_count": 4, "id": "aebb2724", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([12.], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gradf = grad(f)\n", "gradf(np.array([2.0]))" ] }, { "cell_type": "code", "execution_count": 5, "id": "9d891860", "metadata": {}, "outputs": [], "source": [ "def circle(x): return x[0] ** 2 + x[1] ** 2" ] }, { "cell_type": "code", "execution_count": 6, "id": "6a70bd20", "metadata": {}, "outputs": [], "source": [ "J = jacfwd(circle)" ] }, { "cell_type": "code", "execution_count": 7, "id": "dff71dc6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([2., 4.], dtype=float32)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "J(np.array([1.0, 2.0]))" ] }, { "cell_type": "code", "execution_count": 14, "id": "b90095a4", "metadata": {}, "outputs": [], "source": [ "def hessian(f): return jacfwd(jacrev(f))" ] }, { "cell_type": "code", "execution_count": 15, "id": "eb4124f0", "metadata": {}, "outputs": [], "source": [ "H = hessian(circle)" ] }, { "cell_type": "code", "execution_count": 16, "id": "a6d57fb1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[2., 0.],\n", " [0., 2.]], dtype=float32)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "H(np.array([1.0, 2.0]))" ] }, { "cell_type": "code", 
"execution_count": 17, "id": "b50fe2a5", "metadata": {}, "outputs": [], "source": [ "# Paper problem" ] }, { "cell_type": "code", "execution_count": 18, "id": "89a1d95d", "metadata": {}, "outputs": [], "source": [ "def y(x): return ((x * np.sqrt(12 * x - 36)) / (2 * (x - 3)))\n", "def L(x): return np.sqrt(x ** 2 + y(x) ** 2)" ] }, { "cell_type": "code", "execution_count": 19, "id": "402c0f64", "metadata": {}, "outputs": [], "source": [ "gradL = grad(L)" ] }, { "cell_type": "code", "execution_count": 20, "id": "5bf3e5c2", "metadata": {}, "outputs": [], "source": [ "def minGD(x): return x - 0.01 * gradL(x)" ] }, { "cell_type": "code", "execution_count": 21, "id": "e340a396", "metadata": {}, "outputs": [], "source": [ "domain = np.linspace(3.0, 5.0, num=50)" ] }, { "cell_type": "code", "execution_count": 22, "id": "1d054e99", "metadata": {}, "outputs": [], "source": [ "vfunGD = vmap(minGD)" ] }, { "cell_type": "code", "execution_count": 25, "id": "33a48ebf", "metadata": {}, "outputs": [], "source": [ "for epoch in range(50):\n", " domain = vfunGD(domain)" ] }, { "cell_type": "code", "execution_count": 28, "id": "d698e29a", "metadata": {}, "outputs": [], "source": [ "minfun = vmap(L)\n", "minimums = minfun(domain)" ] }, { "cell_type": "code", "execution_count": 29, "id": "eed252b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([ nan, 8.337169 , 7.8126216, 7.8554583, 7.8657427,\n", " 7.8686085, 7.869442 , 7.869495 , 7.8691173, 7.8684244,\n", " 7.867449 , 7.8662043, 7.8646903, 7.8629084, 7.860858 ,\n", " 7.8585405, 7.8559637, 7.853135 , 7.8500714, 7.8467846,\n", " 7.843305 , 7.8396564, 7.8358707, 7.8319883, 7.828048 ,\n", " 7.8240952, 7.8201776, 7.8163476, 7.812655 , 7.809155 ,\n", " 7.805899 , 7.8029428, 7.800338 , 7.7981334, 7.7963796,\n", " 7.795119 , 7.7943954, 7.7942476, 7.7947106, 7.795813 ,\n", " 7.797586 , 7.8000493, 7.8032246, 7.807126 , 7.811767 ,\n", " 7.8171563, 7.8233 , 7.830198 , 7.837856 , 7.846267 ], dtype=float32)" ] }, 
"execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "minimums" ] }, { "cell_type": "code", "execution_count": 30, "id": "92a900e8", "metadata": {}, "outputs": [], "source": [ "arglist = nanargmin(minimums)" ] }, { "cell_type": "code", "execution_count": 31, "id": "a520c53f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "37" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arglist" ] }, { "cell_type": "code", "execution_count": 32, "id": "e5da0824", "metadata": {}, "outputs": [], "source": [ "argmin = domain[arglist]" ] }, { "cell_type": "code", "execution_count": 33, "id": "d2d53cec", "metadata": {}, "outputs": [], "source": [ "minimum = minimums[arglist]" ] }, { "cell_type": "code", "execution_count": 34, "id": "930267ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the minimum is 7.794247627258301 the arg min is 4.505731105804443\n" ] } ], "source": [ "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] }, { "cell_type": "markdown", "id": "c7910312", "metadata": {}, "source": [ "# Newton's method\n", "\n", "Solve for where $f(x) = 0$\n", "\n", "$$\n", " x_{n+1} = x_n - \\frac{f(x_n)}{f^\\prime(x_n)}\n", "$$\n", "\n", "But we search for $f^\\prime(x) = 0$\n", "\n", "$$\n", " x_{n+1} = x_n - \\frac{f^{\\prime}(x_n)}{f^{\\prime\\prime}(x_n)}\n", " $$" ] }, { "cell_type": "code", "execution_count": 35, "id": "426a3238", "metadata": {}, "outputs": [], "source": [ "gradL = grad(L)" ] }, { "cell_type": "code", "execution_count": 36, "id": "21cd42fa", "metadata": {}, "outputs": [], "source": [ "gradL2 = grad(gradL)" ] }, { "cell_type": "code", "execution_count": 39, "id": "4c5b3265", "metadata": {}, "outputs": [], "source": [ "def min_newton(x): return x - gradL(x) / gradL2(x)" ] }, { "cell_type": "code", "execution_count": 38, "id": "4dd342ac", "metadata": {}, "outputs": [], "source": [ "domain = np.linspace(3.0, 5.0, num=50)" 
] }, { "cell_type": "code", "execution_count": 40, "id": "235d15b4", "metadata": {}, "outputs": [], "source": [ "f_newton = vmap(min_newton)" ] }, { "cell_type": "code", "execution_count": 41, "id": "9a88cd8b", "metadata": {}, "outputs": [], "source": [ "for epoch in range(50):\n", " domain = f_newton(domain)" ] }, { "cell_type": "code", "execution_count": 43, "id": "529dbb0e", "metadata": {}, "outputs": [], "source": [ "mins = minfun(domain)" ] }, { "cell_type": "code", "execution_count": 44, "id": "65aa7311", "metadata": {}, "outputs": [], "source": [ "arglist = nanargmin(minimums)\n", "argmin = domain[arglist]\n", "minimum = minimums[arglist]" ] }, { "cell_type": "code", "execution_count": 45, "id": "c0327e0e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the minimum is 7.794247627258301 the arg min is 4.499999523162842\n" ] } ], "source": [ "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] }, { "cell_type": "markdown", "id": "1180ee4c", "metadata": {}, "source": [ "# Multivariable Optimization\n", "\n", "$$\n", " X_{n+1} = X_n - 0.01 \\nabla f(X_n)\n", "$$" ] }, { "cell_type": "code", "execution_count": 48, "id": "2a968948", "metadata": {}, "outputs": [], "source": [ "def paraboloid(x): return (x[0] * x[1] - 2) ** 2 + (x[1] - 3) ** 2" ] }, { "cell_type": "code", "execution_count": 49, "id": "f893af70", "metadata": {}, "outputs": [], "source": [ "f_min = vmap(paraboloid)" ] }, { "cell_type": "code", "execution_count": 50, "id": "0b035649", "metadata": {}, "outputs": [], "source": [ "J = jacfwd(paraboloid)" ] }, { "cell_type": "code", "execution_count": 51, "id": "edd62878", "metadata": {}, "outputs": [], "source": [ "def jacobian_min(x): return x - 0.01 * J(x)" ] }, { "cell_type": "code", "execution_count": 69, "id": "81c302fd", "metadata": {}, "outputs": [], "source": [ "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)" ] }, { "cell_type": "code", "execution_count": 
70, "id": "70d0555a", "metadata": {}, "outputs": [], "source": [ "v_jacobian_min = vmap(jacobian_min)" ] }, { "cell_type": "code", "execution_count": 71, "id": "d53b2818", "metadata": {}, "outputs": [], "source": [ "for epoch in range(150):\n", " domain = v_jacobian_min(domain)" ] }, { "cell_type": "code", "execution_count": 72, "id": "93c3cfdf", "metadata": {}, "outputs": [], "source": [ "mins = f_min(domain)" ] }, { "cell_type": "code", "execution_count": null, "id": "c2f9afa2", "metadata": {}, "outputs": [], "source": [ "# Read the value from `mins` (this section's objective values), not from the\n", "# stale `minimums` array of the single-variable gradient-descent section.\n", "arglist = nanargmin(mins)\n", "argmin = domain[arglist]\n", "minimum = mins[arglist]" ] }, { "cell_type": "code", "execution_count": null, "id": "91ebaa54", "metadata": {}, "outputs": [], "source": [ "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] }, { "cell_type": "markdown", "id": "e5e16d79", "metadata": {}, "source": [ "# The Hessian\n", "\n", "Similar to the Newton method:\n", "$$\n", " X_{n+1} = X_n - 0.01 H^{-1}(X_n) \\nabla f(X_n)\n", "$$" ] }, { "cell_type": "code", "execution_count": 60, "id": "0c11e322", "metadata": {}, "outputs": [], "source": [ "def hessian(f): return jacfwd(jacrev(f))" ] }, { "cell_type": "code", "execution_count": 61, "id": "7a5d3955", "metadata": {}, "outputs": [], "source": [ "H = hessian(paraboloid)" ] }, { "cell_type": "code", "execution_count": 62, "id": "8b32404d", "metadata": {}, "outputs": [], "source": [ "def hessian_min(x): return x - 0.01 * linalg.inv(H(x)) @ J(x)" ] }, { "cell_type": "code", "execution_count": 75, "id": "050f46e5", "metadata": {}, "outputs": [], "source": [ "domain = random.uniform(key, shape=(50, 2), dtype=\"float32\", minval=-5.0, maxval=5.0)" ] }, { "cell_type": "code", "execution_count": 76, "id": "cba5bd2b", "metadata": {}, "outputs": [], "source": [ "v_hessian_min = vmap(hessian_min)" ] }, { "cell_type": "code", "execution_count": 77, 
"id": "27b219df", "metadata": {}, "outputs": [], "source": [ "for epoch in range(150):\n", " domain = v_hessian_min(domain)" ] }, { "cell_type": "code", "execution_count": 78, "id": "dfdf027b", "metadata": {}, "outputs": [], "source": [ "mins = f_min(domain)" ] }, { "cell_type": "code", "execution_count": null, "id": "db074a43", "metadata": {}, "outputs": [], "source": [ "# Read the value from `mins` (paraboloid values after the Hessian iterations),\n", "# not from the stale `minimums` array of the single-variable section.\n", "arglist = nanargmin(mins)\n", "argmin = domain[arglist]\n", "minimum = mins[arglist]\n", "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] }, { "cell_type": "markdown", "id": "e2cc6b7d", "metadata": {}, "source": [ "# Multivariable Constrained Optimization\n", "\n", "Lagrange multipliers:\n", "\n", "$$\n", "\\begin{align}\n", "\\nabla f = \\lambda \\nabla g\\\\\n", "\\nabla f - \\lambda \\nabla g = 0\\\\\n", "\\nabla (f - \\lambda g) = 0\\\\\n", "\\nabla (\\mathcal{L}) = 0\\\\\n", "\\end{align}\n", "$$\n", "\n", "Our iterative approach will be:\n", "\n", "$$\n", "X_{n+1} = X_n - \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n", "$$" ] }, { "cell_type": "code", "execution_count": 80, "id": "a8ff98bc", "metadata": {}, "outputs": [], "source": [ "def f(x): return 4 * (x[0] ** 2) * x[1]\n", "def g(x): return x[0]**2 + x[1]**2 - 3" ] }, { "cell_type": "code", "execution_count": 81, "id": "ba2ce45c", "metadata": {}, "outputs": [], "source": [ "f_min = vmap(f)" ] }, { "cell_type": "code", "execution_count": 82, "id": "c0d03ac0", "metadata": {}, "outputs": [], "source": [ "def lagrange(l): return f(l[0:2]) - l[2]*g(l[0:2])" ] }, { "cell_type": "code", "execution_count": 83, "id": "fc1d5af3", "metadata": {}, "outputs": [], "source": [ "L = jacfwd(lagrange)" ] }, { "cell_type": "code", "execution_count": 84, "id": "60f8134b", "metadata": {}, "outputs": [], "source": [ "gradL = jacfwd(L)" ] }, { "cell_type": "code", "execution_count": 85, "id": 
"ef6ddc53", "metadata": {}, "outputs": [], "source": [ "def lagrange_min(l): return l - linalg.inv(gradL(l)) @ L(l)" ] }, { "cell_type": "code", "execution_count": 86, "id": "c67b6785", "metadata": {}, "outputs": [], "source": [ "domain = random.uniform(key, shape=(50,3), dtype=\"float32\", minval=-5.0, maxval=5.0)" ] }, { "cell_type": "code", "execution_count": 87, "id": "d61873ab", "metadata": {}, "outputs": [], "source": [ "v_lagrange_min = vmap(lagrange_min)" ] }, { "cell_type": "code", "execution_count": 88, "id": "73539993", "metadata": {}, "outputs": [], "source": [ "for epoch in range(150):\n", " domain = v_lagrange_min(domain)" ] }, { "cell_type": "code", "execution_count": 89, "id": "7f4d91c7", "metadata": {}, "outputs": [], "source": [ "mins = f_min(domain)" ] }, 
{ "cell_type": "code", "execution_count": null, "id": "38f56026", "metadata": {}, "outputs": [], "source": [ "# Read the value from `mins` (f evaluated at the stationary points found above),\n", "# not from the stale `minimums` array of the single-variable section.\n", "arglist = nanargmin(mins)\n", "argmin = domain[arglist]\n", "minimum = mins[arglist]\n", "print(f\"the minimum is {minimum} the arg min is {argmin}\")" ] }, { "cell_type": "markdown", "id": "380c75c3", "metadata": {}, "source": [ "# Three Variable Multivariable Constrained Optimization\n", "$$\n", "\\begin{align}\n", "f(x, y, z)& = xyz\\\\\n", "g(x, y, z)& = 2xy + 2yz + 2xz - 64\\\\\n", "\\end{align}\n", "$$\n", "\n", "Our iterative approach will be:\n", "\n", "$$\n", "X_{n+1} = X_n - 0.1 \\nabla^2\\mathcal{L}^{-1}(X_n)\\nabla\\mathcal{L}(X_n)\n", "$$" ] }, { "cell_type": "code", "execution_count": 98, "id": "820312ab", "metadata": {}, "outputs": [], "source": [ "def f(x): return x[0]*x[1]*x[2]\n", "def g(x): return 2*x[0]*x[1] + 2*x[1]*x[2] + 2*x[0]*x[2] - 64\n", "\n", "f_min = vmap(f)\n", "\n", "def lagrange(l): return f(l[0:3]) - l[3]*g(l[0:3])\n", "\n", "L = jacfwd(lagrange)\n", "gradL = jacfwd(L)" ] }, { "cell_type": "code", "execution_count": 99, "id": "61f9db1e", "metadata": {}, "outputs": [], "source": [ "def lagrangian_min(l): return l - 0.1 * linalg.inv(gradL(l)) @ L(l)" ] }, { "cell_type": "code", "execution_count": 100, "id": "9e8d959f", "metadata": {}, "outputs": [], "source": [ "domain = random.uniform(key, shape=(50, 4), dtype=\"float32\", minval = 0, maxval=10)" ] }, { "cell_type": "code", "execution_count": 101, "id": "c3b8c4fa", "metadata": {}, "outputs": [], "source": [ "v_lagrangian_min = vmap(lagrangian_min)" ] }, { "cell_type": "code", "execution_count": 102, "id": "8e83319e", "metadata": {}, "outputs": [], "source": [ "for epoch in range(200):\n", " domain = v_lagrangian_min(domain)" ] }, 
{ "cell_type": "code", "execution_count": null, "id": "d5f28369", "metadata": {}, "outputs": [], "source": [ "# This section selects the largest objective value (nanargmax), so the\n", "# result is reported as a maximum. The 4th component of each row is the Lagrange multiplier.\n", "maximums = f_min(domain)\n", "\n", "arglist = nanargmax(maximums)\n", "argmax = domain[arglist]\n", "maximum = maximums[arglist]\n", "\n", "print(\"The maximum is {}, the argmax is ({},{},{}), the Lagrange multiplier is {}\".format(maximum,\n", " argmax[0], argmax[1], argmax[2], argmax[3]))" ] }, { "cell_type": "markdown", "id": "0bf20ea5", "metadata": {}, "source": [ "# Multivariable MultiConstrained Optimization" ] }, { "cell_type": "code", "execution_count": null, "id": "e834f8ba", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.4" } }, "nbformat": 4, "nbformat_minor": 5 }